diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml
index 710bf8f..ef6c605 100644
--- a/.github/workflows/tests.yaml
+++ b/.github/workflows/tests.yaml
@@ -13,17 +13,41 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
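+      # CURRENT_WEEK is included in the cache key below so the dependency cache
+      # is refreshed at least once a week rather than growing stale.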
+ - name: Set environment variables
+ run: |
+          echo "CURRENT_WEEK=$(date +'%Y-%U')" >> $GITHUB_ENV
- uses: actions/setup-python@v4
with:
python-version: "3.10"
- uses: actions/cache@v4
with:
path: ${{ env.pythonLocation }}
- key: ${{ env.pythonLocation }}-${{ hashFiles('fme/requirements.txt') }}-${{ hashFiles('fme/dev-requirements.txt') }}-${{ hashFiles('fme/constraints.txt') }}
+ key: ${{ env.CURRENT_WEEK }}-${{ env.pythonLocation }}-${{ hashFiles('fme/requirements.txt') }}-${{ hashFiles('fme/dev-requirements.txt') }}-${{ hashFiles('constraints.txt') }}
- name: Install dependencies
run: |
- python -m pip install uv==0.2.5
- uv pip install --system -c constraints.txt -e fme[dev]
+ python -m pip install uv
+ uv pip install --system -c constraints.txt -e ./fme[dev]
- name: Run pytest
run: |
make test
+ cpu-very-fast:
+ runs-on: ubuntu-latest
+ steps:
+ - uses: actions/checkout@v3
+ - name: Set environment variables
+ run: |
+ echo "CURRENT_WEEK=$(date +'%Y-%U')" >> $GITHUB_ENV
+ - uses: actions/setup-python@v4
+ with:
+ python-version: "3.10"
+ - uses: actions/cache@v4
+ with:
+ path: ${{ env.pythonLocation }}
+ key: ${{ env.CURRENT_WEEK }}-${{ env.pythonLocation }}-${{ hashFiles('fme/requirements.txt') }}-${{ hashFiles('fme/dev-requirements.txt') }}-${{ hashFiles('constraints.txt') }}
+ - name: Install dependencies
+ run: |
+ python -m pip install uv
+ uv pip install --system -c constraints.txt -e ./fme[dev]
+ - name: Run pytest
+ run: |
+ make test_very_fast
diff --git a/.gitignore b/.gitignore
index a155547..38980e2 100644
--- a/.gitignore
+++ b/.gitignore
@@ -136,3 +136,6 @@ dmypy.json
# scratch directory for testing
scratch/
+
+# Some in-progress data pipelines get added here
+scripts/data_process/.nfs*
diff --git a/.isort.cfg b/.isort.cfg
deleted file mode 100644
index b9fb3f3..0000000
--- a/.isort.cfg
+++ /dev/null
@@ -1,2 +0,0 @@
-[settings]
-profile=black
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 8e01376..55bdf31 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,41 +1,24 @@
repos:
-- repo: https://github.com/psf/black
- rev: 23.3.0
- hooks:
- - id: black
- additional_dependencies: ["click==8.0.4"]
+- repo: https://github.com/astral-sh/ruff-pre-commit
+ rev: v0.8.1
+ hooks:
+ - id: ruff
+ args: ["--fix", "--config", "fme/pyproject.toml"]
+ - id: ruff-format
- repo: https://github.com/pre-commit/pre-commit-hooks
- rev: v2.3.0
+ rev: v5.0.0
hooks:
- id: check-added-large-files
args: [--maxkb=250]
- id: trailing-whitespace
- - id: flake8
- name: flake8
- language_version: python3
- exclude: "__init__.py"
- args: [--config, setup.cfg]
- - id: flake8
- name: flake8 __init__.py files
- files: "__init__.py"
- # ignore unused import error in __init__.py files
- args: ["--ignore=F401,E203,W503", --config, setup.cfg]
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.2.0
hooks:
- id: mypy
additional_dependencies: ["types-PyYaml==5.4.3"]
- args: [
- --follow-imports, silent, --ignore-missing-imports
- ]
+ args: ["--ignore-missing-imports", "--check-untyped-defs"]
exclude: |
(?x)^(
.+/conf.py |
- .+/setup.py |
.+/conftest.py
- )$
-- repo: https://github.com/pycqa/isort
- rev: 5.11.5
- hooks:
- - id: isort
- name: isort (python)
\ No newline at end of file
+ )$
\ No newline at end of file
diff --git a/Makefile b/Makefile
index 0ff16ed..adf7a36 100644
--- a/Makefile
+++ b/Makefile
@@ -4,7 +4,7 @@ ENVIRONMENT_NAME ?= fme
DEPLOY_TARGET ?= pypi
build_docker_image:
- docker build -f docker/Dockerfile -t $(IMAGE):$(VERSION) .
+ docker build --platform=linux/amd64 -f docker/Dockerfile -t $(IMAGE):$(VERSION) .
enter_docker_image: build_docker_image
docker run -it --rm $(IMAGE):$(VERSION) bash
@@ -12,11 +12,18 @@ enter_docker_image: build_docker_image
# recommended to deactivate current conda environment before running this
create_environment:
conda create -n $(ENVIRONMENT_NAME) python=3.10 pip
- conda run --no-capture-output -n $(ENVIRONMENT_NAME) python -m pip install uv==0.2.5
- conda run --no-capture-output -n $(ENVIRONMENT_NAME) uv pip install -c constraints.txt -e fme[dev]
+ conda run --no-capture-output -n $(ENVIRONMENT_NAME) python -m pip install uv
+ conda run --no-capture-output -n $(ENVIRONMENT_NAME) uv pip install -c constraints.txt -e ./fme[dev,docs]
+ conda run --no-capture-output -n $(ENVIRONMENT_NAME) uv pip install -r analysis-deps.txt
test:
- pytest --durations 20 .
+ pytest --durations 40 .
+
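+# The --fast and --very-fast options are defined in the top-level conftest.py;
+# they skip slower tests and tighten the per-test timeout.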
+test_fast:
+ pytest --durations 40 --fast .
+
+test_very_fast:
+ pytest --durations 40 --very-fast .
# For maintainer use only
# requires fme[deploy] to be installed
diff --git a/README.md b/README.md
index f6f944b..2efacac 100644
--- a/README.md
+++ b/README.md
@@ -43,4 +43,4 @@ gs://ai2cm-public-requester-pays/2024-11-13-ai2-climate-emulator-v2-amip/data/er
The dataset used in the [ACE2-SOM paper](https://arxiv.org/abs/2412.04418) is available at:
```
gs://ai2cm-public-requester-pays/2024-12-05-ai2-climate-emulator-v2-som/SHiELD-SOM-C96
-```
\ No newline at end of file
+```
diff --git a/analysis-deps.txt b/analysis-deps.txt
new file mode 100644
index 0000000..3edb535
--- /dev/null
+++ b/analysis-deps.txt
@@ -0,0 +1,13 @@
+# These are packages that are convenient to have installed for ad-hoc analysis,
+# but which are not requirements of the "fme" package. We do not relist the fme
+# dependencies here.
+beaker-py
+Bottleneck
+cartopy>=0.22.0
+dask[distributed]
+ipywidgets
+nc-time-axis
+jupyterlab
+pyproj<3.7
+seaborn
+bokeh>=3.1.0
\ No newline at end of file
diff --git a/conftest.py b/conftest.py
new file mode 100644
index 0000000..265e2a4
--- /dev/null
+++ b/conftest.py
@@ -0,0 +1,68 @@
+import signal
+
+import pytest
+
+
+def pytest_addoption(parser):
+ parser.addoption(
+ "--fast",
+ action="store_true",
+ default=False,
+ help="Skip slow tests",
+ )
+ parser.addoption(
+ "--very-fast",
+ action="store_true",
+ default=False,
+ help="Run only very fast tests (< 5 seconds)",
+ )
+
+
+@pytest.fixture
+def skip_slow(request, very_fast_only):
+ return very_fast_only or request.config.getoption("--fast")
+
+
+@pytest.fixture
+def very_fast_only(request):
+ return request.config.getoption("--very-fast")
+
+
+class TimeoutException(Exception):
+ pass
+
+
+def timeout_handler(signum, frame):
+ raise TimeoutException("Test took too long")
+
+
+@pytest.fixture
+def pdb_enabled(request):
+ return request.config.getoption("--pdb")
+
+
+@pytest.fixture(autouse=True, scope="function")
+def enforce_timeout(skip_slow, very_fast_only, pdb_enabled):
+ if pdb_enabled:
+ yield # Do not enforce timeout if we are debugging
+ return
+ if very_fast_only:
+ timeout_seconds = 3
+ elif skip_slow:
+ timeout_seconds = 30
+ else:
+ timeout_seconds = 60
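+    # Note: this relies on SIGALRM, which is only available on POSIX systems.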
+ signal.signal(signal.SIGALRM, timeout_handler)
+ signal.alarm(timeout_seconds) # Set the timeout for the test
+ try:
+ yield
+ finally:
+ signal.alarm(0) # Disable the alarm after the test completes
+
+
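+# Convert the TimeoutException raised by the SIGALRM handler into an ordinary
+# pytest failure rather than an unhandled error.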
+@pytest.hookimpl(tryfirst=True, hookwrapper=True)
+def pytest_runtest_call(item):
+ try:
+ yield
+ except TimeoutException:
+ pytest.fail("Test failed due to timeout")
diff --git a/fme/conftest.py b/fme/conftest.py
deleted file mode 100644
index 475fb2d..0000000
--- a/fme/conftest.py
+++ /dev/null
@@ -1,12 +0,0 @@
-import pytest
-
-
-def pytest_addoption(parser):
- parser.addoption(
- "--fast", action="store_true", default=False, help="Run only fast tests"
- )
-
-
-@pytest.fixture
-def skip_slow(request):
- return request.config.getoption("--fast")
diff --git a/fme/docs/_static/Ai2_icon_pink_RGB.png b/fme/docs/_static/Ai2_icon_pink_RGB.png
new file mode 100644
index 0000000..e418a9d
Binary files /dev/null and b/fme/docs/_static/Ai2_icon_pink_RGB.png differ
diff --git a/fme/docs/_static/Ai2_icon_pink_padding_RGB.png b/fme/docs/_static/Ai2_icon_pink_padding_RGB.png
new file mode 100644
index 0000000..d84ccef
Binary files /dev/null and b/fme/docs/_static/Ai2_icon_pink_padding_RGB.png differ
diff --git a/fme/docs/_static/custom.css b/fme/docs/_static/custom.css
new file mode 100644
index 0000000..14354dc
--- /dev/null
+++ b/fme/docs/_static/custom.css
@@ -0,0 +1,21 @@
+body[data-theme="dark"] {
+ --code-block-background: #202020;
+}
+
+body[data-theme="light"] {
+ --code-block-background: #f8f9fb;
+}
+
+body[data-theme="auto"] {
+ --code-block-background: #f8f9fb;
+}
+
+@media (prefers-color-scheme: dark) {
+ body[data-theme="auto"] {
+ --code-block-background: #202020;
+ }
+}
+
+div.highlight pre {
+ background: var(--code-block-background);
+}
\ No newline at end of file
diff --git a/fme/docs/builder.rst b/fme/docs/builder.rst
index 5a23d31..e32f7c0 100644
--- a/fme/docs/builder.rst
+++ b/fme/docs/builder.rst
@@ -51,7 +51,7 @@ Let's define a training configuration ``TrainConfig`` containing an ``OptimizerC
"""
Configuration for an optimizer.
- Attributes:
+ Parameters:
optimizer_type: The type of optimizer to use.
lr: The learning rate.
kwargs: Additional keyword arguments to pass to the optimizer.
diff --git a/fme/docs/conf.py b/fme/docs/conf.py
index eeae207..20f9794 100755
--- a/fme/docs/conf.py
+++ b/fme/docs/conf.py
@@ -24,7 +24,7 @@
sys.path.insert(0, os.path.abspath(".."))
import fme # noqa
-import fme.ace
+import fme.core.registry
# -- General configuration ---------------------------------------------
@@ -33,7 +33,7 @@
# needs_sphinx = '1.0'
# Fetch the dynamic data
-module_types = fme.ace.get_available_module_types()
+module_types = fme.core.registry.ModuleSelector.get_available_types()
# Create a dynamic rst snippet that can be included in your documentation
rst_snippet = f".. code-block:: text\n\n {module_types}"
@@ -50,7 +50,15 @@
"sphinx.ext.autodoc",
"sphinx.ext.viewcode",
"sphinx.ext.doctest",
+ "sphinx.ext.intersphinx",
+ "sphinx_autodoc_typehints",
]
+autodoc_typehints = "description"
+
+# Intersphinx configuration to link to other projects' documentation
+intersphinx_mapping = {
+ "python": ("https://docs.python.org/3", None),
+}
# Add any paths that contain templates here, relative to this directory.
templates_path = ["_templates"]
@@ -93,6 +101,9 @@
# The name of the Pygments (syntax highlighting) style to use.
pygments_style = "sphinx"
+# Include default values when documenting parameter types.
+typehints_defaults = "comma"
+
# If true, `todo` and `todoList` produce output, else they produce nothing.
todo_include_todos = False
@@ -102,18 +113,37 @@
# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
#
-html_theme = "sphinx_rtd_theme"
+html_theme = "furo"
+html_title = f"fme v{fme.__version__}"
# Theme options are theme-specific and customize the look and feel of a
# theme further. For a list of options available for each theme, see the
# documentation.
#
-# html_theme_options = {}
+html_theme_options = {
+ "light_logo": "Ai2_icon_pink_padding_RGB.png",
+ "dark_logo": "Ai2_icon_pink_padding_RGB.png",
+ "footer_icons": [
+ {
+ "name": "GitHub",
+ "url": "https://github.com/ai2cm/ace",
+ "html": """
+
+ """, # noqa: E501
+ "class": "",
+ },
+ ],
+}
# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
-html_static_path = []
+html_static_path = ["_static"]
+html_css_files = ["custom.css"]
+
+html_favicon = "_static/Ai2_icon_pink_RGB.png"
# -- Options for HTMLHelp output ---------------------------------------
@@ -176,3 +206,9 @@
"Miscellaneous",
),
]
+
+
+# -- Options for doctest -----------------------------------------------
+doctest_global_setup = """
+import fme
+"""
diff --git a/fme/docs/configs/explicit-indices.yaml b/fme/docs/configs/explicit-indices.yaml
new file mode 100644
index 0000000..3080a5e
--- /dev/null
+++ b/fme/docs/configs/explicit-indices.yaml
@@ -0,0 +1,3 @@
+path: initial_conditions.nc
+start_indices:
+ list: [0, 3, 7]
diff --git a/fme/docs/configs/inference-ic-indices.yaml b/fme/docs/configs/inference-ic-indices.yaml
new file mode 100644
index 0000000..8c30681
--- /dev/null
+++ b/fme/docs/configs/inference-ic-indices.yaml
@@ -0,0 +1,5 @@
+path: initial_conditions.nc
+start_indices:
+ n_initial_conditions: 3
+ first: 1
+ interval: 2
diff --git a/fme/docs/configs/timestamp-list.yaml b/fme/docs/configs/timestamp-list.yaml
new file mode 100644
index 0000000..729f315
--- /dev/null
+++ b/fme/docs/configs/timestamp-list.yaml
@@ -0,0 +1,5 @@
+path: initial_conditions.nc
+start_indices:
+ times:
+ - "2021-01-01T00:00:00"
+ - "2021-02-01T00:00:00"
diff --git a/fme/docs/index.rst b/fme/docs/index.rst
index 6657d3b..df4c0af 100644
--- a/fme/docs/index.rst
+++ b/fme/docs/index.rst
@@ -1,6 +1,9 @@
-Welcome to ACE's documentation!
+fme: Full Model Emulation
======================================
+**fme** ("full model emulation") is a python package for training and running
+climate model emulators, such as the Ai2 Climate Emulator.
+
.. toctree::
:maxdepth: 1
:caption: Contents:
diff --git a/fme/docs/inference-config.yaml b/fme/docs/inference-config.yaml
index 277ec47..e616530 100644
--- a/fme/docs/inference-config.yaml
+++ b/fme/docs/inference-config.yaml
@@ -1,19 +1,25 @@
experiment_dir: inference_output
-n_forward_steps: 400 # 100 days
+n_forward_steps: 400 # 100 days
forward_steps_in_memory: 50
-checkpoint_path: ckpt.tar
+checkpoint_path: ace_ckpt.tar
logging:
log_to_screen: true
log_to_wandb: false
log_to_file: true
project: ace
- entity: your_wandb_entity
initial_condition:
- path: initial_condition/data.nc
+ path: climSST/ic_2021.zarr
+ start_indices:
+ n_initial_conditions: 2
+ first: 0
+ interval: 3
+ engine: zarr
forcing_loader:
dataset:
- data_path: forcing
- num_data_workers: 8
+ data_path: climSST
+ file_pattern: forcing_2021.zarr
+ engine: zarr
+    n_repeats: 2 # use this to extend the one year of forcing data to the desired length
+ num_data_workers: 2
data_writer:
save_prediction_files: false
- save_monthly_files: false
diff --git a/fme/docs/inference_config.rst b/fme/docs/inference_config.rst
index dc8612d..5f1a98d 100644
--- a/fme/docs/inference_config.rst
+++ b/fme/docs/inference_config.rst
@@ -11,25 +11,30 @@ The example assumes you are running in a directory structure like:
::
.
- ├── ckpt.tar
- ├── initial_condition
- │ └── data.nc # name must reflect the path in the config
- └── forcing
- ├── data1.nc # files can have any name, but must sort into time-sequential order
- ├── data2.nc # can have any number of netCDF files
- └── ...
+ ├── ace_ckpt.tar
+ ├── climSST
+ │ ├── forcing_2021.zarr
+ │ ├── ic_2021-01-01.zarr
+ │ └── ic_2021.zarr
+ └── inference-config.yaml
-The ``.nc`` files correspond to data files like ``2021010100.nc`` in the `Zenodo repository`_, while ``ckpt.tar`` corresponds to a file like ``ace_ckpt.tar`` in that repository.
+This directory includes a model checkpoint (``ace_ckpt.tar``), forcing data (``forcing_2021.zarr``), and an initial condition (e.g., ``ic_2021-01-01.zarr``). You can find the forcing and initial condition data in the `Zenodo repository`_.
The specified initial condition file should contain a time dimension of at least length 1, but can also
-contain multiple times. If multiple times are present and `start_indices` is not specified in the
-configuration, the inference will run an ensemble using all times in the initial condition file.
-Selections from initial conditions can be made using the `start_indices` parameter in the configuration.
+contain multiple times. If multiple times are present and ``start_indices`` is not specified in the
+:class:`fme.ace.InitialConditionConfig` configuration, the inference will run an ensemble using all times
+in the initial condition file. The ``ic_2021.zarr`` file is an example of a file with multiple times, containing
+initial conditions for each month of 2021. For examples of selecting specific initial
+conditions, see :ref:`initial-condition-examples`.
-While netCDF files are specified in the example, Zarr datasets are also compatible. E.g.,
-specifying `data.zarr` as the `path` setting `engine` to `zarr` in the dataset configuration.
+While Zarr files are specified in the example, netCDF files are also compatible: specify the folder
+containing the netCDF files as the ``data_path`` and set ``engine`` to ``netcdf4``
+in the dataset configuration. See :class:`fme.ace.XarrayDataConfig` for the available options.
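+
+For example, a minimal sketch of the forcing dataset section using netCDF files might look like the
+following (the directory name ``forcing_netcdfs`` here is purely illustrative):
+
+.. code-block:: yaml
+
+    forcing_loader:
+      dataset:
+        data_path: forcing_netcdfs  # directory of time-ordered netCDF files
+        engine: netcdf4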
-.. _Zenodo repository: https://zenodo.org/doi/10.5281/zenodo.10791086
+Example YAML Configuration
+---------------------------
+
+.. _Zenodo repository: https://zenodo.org/records/13787710
.. literalinclude:: inference-config.yaml
:language: yaml
@@ -52,9 +57,9 @@ specifying `data.zarr` as the `path` setting `engine` to `zarr` in the dataset
)
# these paths are used in the documentation on this page
# if they change then update the docs!
- assert config.checkpoint_path == "ckpt.tar"
- assert config.initial_condition.path == "initial_condition/data.nc"
- assert config.forcing_loader.dataset.data_path == "forcing"
+ assert config.checkpoint_path == "ace_ckpt.tar"
+ assert config.initial_condition.path == "climSST/ic_2021.zarr"
+ assert config.forcing_loader.dataset.data_path == "climSST"
print("Loaded successfully")
.. testoutput::
@@ -62,6 +67,9 @@ specifying `data.zarr` as the `path` setting `engine` to `zarr` in the dataset
Loaded successfully
+Configuration structure
+-----------------------
+
We use the :ref:`Builder pattern ` to load this configuration into a multi-level dataclass structure.
The configuration is divided into several sub-configurations, each with its own dataclass.
The top-level configuration is the :class:`fme.ace.InferenceConfig` class.
@@ -99,3 +107,89 @@ The sub-configurations are:
.. autoclass:: fme.ace.OceanConfig
:show-inheritance:
:noindex:
+
+
+ .. _initial-condition-examples:
+
+:class:`fme.ace.InitialConditionConfig` Examples
+-------------------------------------------------
+
+The ``start_indices`` attribute can be used to specify which initial conditions
+to use when multiple are present in the dataset (instead of using all available).
+The following examples show selections using the YAML builder pattern for
+an ``InitialConditionConfig``.
+
+
+:class:`fme.ace.InferenceInitialConditionIndices`
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+Select a number of regularly spaced initial conditions.
+
+.. literalinclude:: configs/inference-ic-indices.yaml
+ :language: yaml
+
+.. testcode::
+ :hide:
+
+ from fme.ace import InitialConditionConfig
+ import dacite
+ import yaml
+
+ with open('configs/inference-ic-indices.yaml', 'r') as f:
+ config_dict = yaml.safe_load(f)
+
+ config = dacite.from_dict(
+ InitialConditionConfig,
+ data=config_dict,
+ config=dacite.Config(strict=True)
+ )
+ assert config.start_indices.n_initial_conditions == 3
+
+:class:`fme.ace.TimestampList`
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+Select two timestamps from the initial conditions.
+
+.. literalinclude:: configs/timestamp-list.yaml
+ :language: yaml
+
+.. testcode::
+ :hide:
+
+ from fme.ace import InitialConditionConfig
+ import dacite
+ import yaml
+
+ with open('configs/timestamp-list.yaml', 'r') as f:
+ config_dict = yaml.safe_load(f)
+
+ config = dacite.from_dict(
+ InitialConditionConfig,
+ data=config_dict,
+ config=dacite.Config(strict=True)
+ )
+ assert config.start_indices.times[0] == '2021-01-01T00:00:00'
+
+:class:`fme.ace.ExplicitIndices`
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+Select specific indices from the initial conditions.
+
+.. literalinclude:: configs/explicit-indices.yaml
+ :language: yaml
+
+.. testcode::
+ :hide:
+
+ from fme.ace import InitialConditionConfig
+ import dacite
+ import yaml
+
+ with open('configs/explicit-indices.yaml', 'r') as f:
+ config_dict = yaml.safe_load(f)
+
+ config = dacite.from_dict(
+ InitialConditionConfig,
+ data=config_dict,
+ config=dacite.Config(strict=True)
+ )
+ assert config.start_indices.list[1] == 3
+
+
diff --git a/fme/docs/installation.rst b/fme/docs/installation.rst
index eb5b27e..335f58e 100644
--- a/fme/docs/installation.rst
+++ b/fme/docs/installation.rst
@@ -34,7 +34,7 @@ A make target is available to build a conda environment:
make create_environment
-This will create an environment named ``fme``, and should use the same package versions we have used in development. If you would like a different name, set the ENVIRONMENT_NAME variable:
+This will create an environment named ``fme``. If you would like a different name, set the ``ENVIRONMENT_NAME`` variable:
.. code-block:: shell
diff --git a/fme/docs/modules.rst b/fme/docs/modules.rst
index a3dd14e..a16b404 100644
--- a/fme/docs/modules.rst
+++ b/fme/docs/modules.rst
@@ -15,7 +15,7 @@ The following module types are available:
.. include:: available_modules.rst
-.. autofunction:: fme.ace.get_available_module_types
+.. autofunction:: fme.core.registry.ModuleSelector.get_available_types
The following module builders are available:
@@ -30,3 +30,9 @@ The following module builders are available:
:undoc-members:
:show-inheritance:
:noindex:
+
+.. autoclass:: fme.ace.HEALPixRecUNetBuilder
+ :members:
+ :undoc-members:
+ :show-inheritance:
+ :noindex:
\ No newline at end of file
diff --git a/fme/docs/quickstart.rst b/fme/docs/quickstart.rst
index 1488aca..64f1ba7 100644
--- a/fme/docs/quickstart.rst
+++ b/fme/docs/quickstart.rst
@@ -20,7 +20,7 @@ For the optional Weights and Biases (wandb) integration, you will need to set th
export WANDB_API_KEY=wandb-api-key
-where `wandb-api-key` is created and retrieved from the "API Keys" section of the `Wandb`_ settings page.
+where ``wandb-api-key`` is created and retrieved from the "API Keys" section of the `Wandb`_ settings page.
.. _Wandb: https://wandb.ai/settings
@@ -58,7 +58,7 @@ For example, the 10-year validation data (approx. 190GB) can be downloaded with:
gsutil -m -u YOUR_GCP_PROJECT cp -r gs://ai2cm-public-requester-pays/2023-11-29-ai2-climate-emulator-v1/data/repeating-climSST-1deg-netCDFs/validation .
-It is possible to download a portion of the dataset only, but it is necessary to have enough data to span the desired prediction period. The checkpoint is also available on GCS at `gs://ai2cm-public-requester-pays/2023-11-29-ai2-climate-emulator-v1/checkpoints/ace_ckpt.tar`.
+It is possible to download a portion of the dataset only, but it is necessary to have enough data to span the desired prediction period. The checkpoint is also available on GCS at ``gs://ai2cm-public-requester-pays/2023-11-29-ai2-climate-emulator-v1/checkpoints/ace_ckpt.tar``.
.. _Zenodo repository: https://zenodo.org/doi/10.5281/zenodo.10791086
.. _requester pays: https://cloud.google.com/storage/docs/requester-pays
@@ -83,7 +83,8 @@ If you run into configuration issues, you can validate your configuration with
While inference can be performed without a GPU, it may be very slow. If running on a Mac, set the environmental variable
``export FME_USE_MPS=1`` to enable using the `Metal Performance Shaders`_ framework for GPU acceleration. Note this backend is
- not fully featured and it may not work with all inference features or for training.
+ not fully featured and it may not work with all inference features or for training. It is recommended to use the latest version
+ of torch if using MPS.
.. _Metal Performance Shaders: https://developer.apple.com/metal/pytorch/
@@ -171,7 +172,7 @@ Then in the ``fme`` conda environment, run evaluation with:
torchrun --nproc_per_node RANK_COUNT -m fme.ace.train config-train.yaml
-where RANK_COUNT is how many processors you want to run on.
+where ``RANK_COUNT`` is how many processors you want to run on.
This will typically be the number of GPUs you have available.
If running on a single GPU, you can omit the `torchrun` command and use ``python -m`` instead.
diff --git a/fme/docs/requirements.txt b/fme/docs/requirements.txt
index 6deb546..41b311d 100644
--- a/fme/docs/requirements.txt
+++ b/fme/docs/requirements.txt
@@ -1,2 +1,3 @@
+furo==2024.04.27
sphinx==7.0.0
-sphinx-rtd-theme==2.0.0
+sphinx_autodoc_typehints
\ No newline at end of file
diff --git a/fme/fme/ace/__init__.py b/fme/fme/ace/__init__.py
index 8d30ba0..27be7f1 100644
--- a/fme/fme/ace/__init__.py
+++ b/fme/fme/ace/__init__.py
@@ -1,5 +1,16 @@
import sys
+from fme.ace.data_loading.inference import (
+ ExplicitIndices,
+ InferenceInitialConditionIndices,
+ TimestampList,
+)
+from fme.ace.data_loading.perturbation import (
+ ConstantConfig,
+ GreensFunctionConfig,
+ PerturbationSelector,
+ SSTPerturbation,
+)
from fme.ace.inference.data_writer.time_coarsen import TimeCoarsenConfig
from fme.ace.inference.evaluator import (
DataWriterConfig,
@@ -16,22 +27,31 @@
InitialConditionConfig,
run_inference_from_config,
)
-from fme.ace.registry.sfno import SFNO_V0_1_0, SphericalFourierNeuralOperatorBuilder
-from fme.core.corrector import CorrectorConfig
-from fme.core.data_loading.config import TimeSlice, XarrayDataConfig
-from fme.core.data_loading.inference import (
- ExplicitIndices,
- InferenceInitialConditionIndices,
- TimestampList,
+from fme.ace.models.healpix.healpix_activations import (
+ CappedGELUConfig,
+ DownsamplingBlockConfig,
)
+from fme.ace.models.healpix.healpix_blocks import ConvBlockConfig, RecurrentBlockConfig
+from fme.ace.registry.hpx import (
+ HEALPixRecUNetBuilder,
+ UNetDecoderConfig,
+ UNetEncoderConfig,
+)
+from fme.ace.registry.sfno import SFNO_V0_1_0, SphericalFourierNeuralOperatorBuilder
+from fme.core.corrector.corrector import CorrectorConfig
+from fme.core.corrector.ocean import OceanCorrectorConfig
+from fme.core.dataset.config import TimeSlice, XarrayDataConfig
+from fme.core.gridded_ops import GriddedOperations
from fme.core.loss import WeightedMappingLossConfig
-from fme.core.normalizer import FromStateNormalizer, NormalizationConfig
+from fme.core.normalizer import NormalizationConfig
from fme.core.ocean import SlabOceanConfig
from fme.core.optimization import SchedulerConfig
from fme.core.parameter_init import FrozenParameterConfig, ParameterInitializationConfig
-from fme.core.registry import ModuleSelector, get_available_module_types, register
+from fme.core.registry.corrector import CorrectorSelector
+from fme.core.registry.module import ModuleSelector
+from fme.core.typing_ import Slice
-from .train.train import run_train_from_config
+from .train.train import run_train
from .train.train_config import (
CopyWeightsConfig,
DataLoaderConfig,
@@ -41,7 +61,6 @@
LoggingConfig,
OptimizationConfig,
SingleModuleStepperConfig,
- Slice,
TrainConfig,
)
diff --git a/fme/fme/core/aggregator/__init__.py b/fme/fme/ace/aggregator/__init__.py
similarity index 100%
rename from fme/fme/core/aggregator/__init__.py
rename to fme/fme/ace/aggregator/__init__.py
diff --git a/fme/fme/core/aggregator/inference/__init__.py b/fme/fme/ace/aggregator/inference/__init__.py
similarity index 100%
rename from fme/fme/core/aggregator/inference/__init__.py
rename to fme/fme/ace/aggregator/inference/__init__.py
diff --git a/fme/fme/core/aggregator/inference/annual.py b/fme/fme/ace/aggregator/inference/annual.py
similarity index 89%
rename from fme/fme/core/aggregator/inference/annual.py
rename to fme/fme/ace/aggregator/inference/annual.py
index d3c736c..7e51dfb 100644
--- a/fme/fme/core/aggregator/inference/annual.py
+++ b/fme/fme/ace/aggregator/inference/annual.py
@@ -1,30 +1,30 @@
import dataclasses
import datetime
-from typing import Any, Dict, List, Mapping, Optional, Tuple
+from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple
import numpy as np
import torch
import xarray as xr
from matplotlib.figure import Figure
-from fme.core.data_loading.data_typing import VariableMetadata
+from fme.core.dataset.data_typing import VariableMetadata
from fme.core.device import get_device
from fme.core.distributed import Distributed
-from fme.core.metrics import weighted_mean
+from fme.core.gridded_ops import GriddedOperations
from fme.core.typing_ import TensorMapping
class GlobalMeanAnnualAggregator:
def __init__(
self,
- area_weights: torch.Tensor,
+ ops: GriddedOperations,
timestep: datetime.timedelta,
- metadata: Optional[Mapping[str, VariableMetadata]] = None,
+ variable_metadata: Optional[Mapping[str, VariableMetadata]] = None,
monthly_reference_data: Optional[xr.Dataset] = None,
):
- self.area_weights = area_weights
+ self._area_weighted_mean = ops.area_weighted_mean
self.timestep = timestep
- self.metadata = metadata
+ self.variable_metadata = variable_metadata
self._target_datasets: Optional[List[xr.Dataset]] = None
self._gen_datasets: Optional[List[xr.Dataset]] = None
self._monthly_reference_data = monthly_reference_data
@@ -37,7 +37,7 @@ def _get_reference(self, name: str) -> Optional["VariableReferenceData"]:
if name not in self._monthly_reference_data:
return None
self._variable_reference_data[name] = process_monthly_reference(
- self._monthly_reference_data, self.area_weights, name
+ self._monthly_reference_data, self._area_weighted_mean, name
)
return self._variable_reference_data[name]
@@ -51,12 +51,10 @@ def record_batch(
"""Record a batch of data for computing time variability statistics."""
target_data_area_mean, gen_data_area_mean = {}, {}
for name in gen_data.keys():
- target_data_area_mean[name] = weighted_mean(
- target_data[name], self.area_weights, dim=(-1, -2)
- ).cpu()
- gen_data_area_mean[name] = weighted_mean(
- gen_data[name], self.area_weights, dim=(-1, -2)
+ target_data_area_mean[name] = self._area_weighted_mean(
+ target_data[name]
).cpu()
+ gen_data_area_mean[name] = self._area_weighted_mean(gen_data[name]).cpu()
target_ds = to_dataset(target_data_area_mean, time)
gen_ds = to_dataset(gen_data_area_mean, time)
@@ -190,6 +188,19 @@ def get_logs(self, label: str) -> Dict[str, Any]:
logs.update({f"{label}{name}": metrics[name] for name in metrics.keys()})
return logs
+ def get_dataset(self) -> xr.Dataset:
+ gathered = self._get_gathered_means()
+ if gathered is None:
+ return xr.Dataset()
+ target, gen = gathered
+ return xr.concat(
+ [
+ target.expand_dims({"source": ["target"]}),
+ gen.expand_dims({"source": ["prediction"]}),
+ ],
+ dim="source",
+ )
+
@dataclasses.dataclass
class VariableReferenceData:
@@ -199,14 +210,12 @@ class VariableReferenceData:
def process_monthly_reference(
- monthly_reference_data: xr.Dataset, area_weights: torch.Tensor, name: str
+ monthly_reference_data: xr.Dataset,
+ area_weighted_mean: Callable[[torch.Tensor], torch.Tensor],
+ name: str,
) -> VariableReferenceData:
ref_global_mean = xr.DataArray(
- weighted_mean(
- torch.as_tensor(monthly_reference_data[name].values),
- weights=area_weights.cpu(),
- dim=(-1, -2),
- ),
+ area_weighted_mean(torch.as_tensor(monthly_reference_data[name].values)),
dims=monthly_reference_data[name].dims[:-2],
coords={"time": monthly_reference_data[name].coords["time"]},
)
diff --git a/fme/fme/core/aggregator/inference/enso/__init__.py b/fme/fme/ace/aggregator/inference/enso/__init__.py
similarity index 100%
rename from fme/fme/core/aggregator/inference/enso/__init__.py
rename to fme/fme/ace/aggregator/inference/enso/__init__.py
diff --git a/fme/fme/core/aggregator/inference/enso/enso.py b/fme/fme/ace/aggregator/inference/enso/enso.py
similarity index 73%
rename from fme/fme/core/aggregator/inference/enso/enso.py
rename to fme/fme/ace/aggregator/inference/enso/enso.py
index cf14f04..440841c 100644
--- a/fme/fme/core/aggregator/inference/enso/enso.py
+++ b/fme/fme/ace/aggregator/inference/enso/enso.py
@@ -3,14 +3,15 @@
import cftime
import matplotlib.pyplot as plt
+import numpy as np
import torch
import xarray as xr
-from fme.core.aggregator.plotting import get_cmap_limits, plot_imshow, plot_paneled_data
-from fme.core.data_loading.data_typing import VariableMetadata
+from fme.ace.aggregator.plotting import get_cmap_limits, plot_imshow, plot_paneled_data
+from fme.core.dataset.data_typing import VariableMetadata
from fme.core.device import get_device
from fme.core.distributed import Distributed
-from fme.core.metrics import root_mean_squared_error
+from fme.core.gridded_ops import GriddedOperations
from fme.core.typing_ import TensorDict, TensorMapping
from fme.core.wandb import WandB
@@ -26,6 +27,7 @@ def index_data_array(
Args:
index_data: List of (time, index) tuples.
+ calendar: Calendar for the time coordinate.
Returns:
ENSO index data as an xarray DataArray.
@@ -73,31 +75,31 @@ class EnsoCoefficientEvaluatorAggregator:
data are time-aggregated.
Args:
- initial_times: Initial times for each sample.
+ initial_time: Initial time for each sample.
n_forward_timesteps: Number of timesteps for each sample.
timestep: Timestep duration.
- area_weights: Area weights for spatial averaging.
- metadata: Metadata for the variables in the data.
+ gridded_operations: GriddedOperations instance for area-weighted RMSE.
+ variable_metadata: Metadata for the variables in the data.
"""
def __init__(
self,
- initial_times: xr.DataArray,
+ initial_time: xr.DataArray,
n_forward_timesteps: int,
timestep: datetime.timedelta,
- area_weights: torch.Tensor,
- metadata: Optional[Mapping[str, VariableMetadata]] = None,
+ gridded_operations: GriddedOperations,
+ variable_metadata: Optional[Mapping[str, VariableMetadata]] = None,
):
- self._sample_index_series: List[
- Optional[xr.DataArray]
- ] = get_sample_index_series(
- self.enso_index, initial_times, n_forward_timesteps, timestep
+ self._sample_index_series: List[Optional[xr.DataArray]] = (
+ get_sample_index_series(
+ self.enso_index, initial_time, n_forward_timesteps, timestep
+ )
)
- self._area_weights: torch.Tensor = area_weights
- if metadata is not None:
- self._metadata: Mapping[str, VariableMetadata] = metadata
+ self._ops = gridded_operations
+ if variable_metadata is not None:
+ self._variable_metadata: Mapping[str, VariableMetadata] = variable_metadata
else:
- self._metadata = {}
+ self._variable_metadata = {}
n_samples = len(self._sample_index_series)
self._target_covariances: List[TensorDict] = [{} for _ in range(n_samples)]
self._gen_covariances: List[TensorDict] = [{} for _ in range(n_samples)]
@@ -129,38 +131,36 @@ def record_batch(
), "number of index series must match number of samples"
for i_sample, sample_index_series in enumerate(self._sample_index_series):
if sample_index_series is not None:
- sample_index_series_reindexed = sample_index_series.reindex(
- time=time.isel(sample=i_sample), method="nearest"
- ).values
- sample_index_series_reindexed = torch.tensor(
- sample_index_series_reindexed,
+ sample_index_series_window = sample_index_series.sel(
+ time=time.isel(sample=i_sample)
+ )
+ sample_index_series_window = torch.tensor(
+ sample_index_series_window.values,
device=get_device(),
dtype=torch.float32,
)
- self._index_variance[i_sample] += (
- sample_index_series_reindexed**2
- ).sum()
+ self._index_variance[i_sample] += (sample_index_series_window**2).sum()
for name, data in target_data.items():
if name not in self._target_covariances[i_sample]:
- self._target_covariances[i_sample][
- name
- ] = data_index_covariance(
- data[i_sample, :], sample_index_series_reindexed
+ self._target_covariances[i_sample][name] = (
+ data_index_covariance(
+ data[i_sample, :], sample_index_series_window
+ )
)
else:
- self._target_covariances[i_sample][
- name
- ] += data_index_covariance(
- data[i_sample, :], sample_index_series_reindexed
+ self._target_covariances[i_sample][name] += (
+ data_index_covariance(
+ data[i_sample, :], sample_index_series_window
+ )
)
for name, data in gen_data.items():
if name not in self._gen_covariances[i_sample]:
self._gen_covariances[i_sample][name] = data_index_covariance(
- data[i_sample, :], sample_index_series_reindexed
+ data[i_sample, :], sample_index_series_window
)
else:
self._gen_covariances[i_sample][name] += data_index_covariance(
- data[i_sample, :], sample_index_series_reindexed
+ data[i_sample, :], sample_index_series_window
)
def _compute_coefficients(
@@ -247,9 +247,9 @@ def get_logs(self, label: str) -> Dict[str, Any]:
wandb = WandB.get_instance()
images, metrics = {}, {}
for name in gen_coefficients.keys():
- if name in self._metadata:
- caption_name = self._metadata[name].long_name
- caption_units = self._metadata[name].units
+ if name in self._variable_metadata:
+ caption_name = self._variable_metadata[name].long_name
+ caption_units = self._variable_metadata[name].units
else:
caption_name = name
caption_units = "unknown units"
@@ -267,10 +267,9 @@ def get_logs(self, label: str) -> Dict[str, Any]:
)
images.update({f"coefficient_maps/{name}": coefficient_map})
rmse = float(
- root_mean_squared_error(
+ self._ops.area_weighted_rmse(
predicted=gen_coefficients[name],
truth=target_coefficients[name],
- weights=self._area_weights,
)
.cpu()
.numpy()
@@ -299,10 +298,45 @@ def get_logs(self, label: str) -> Dict[str, Any]:
logs.update({f"{label}{name}": metrics[name] for name in metrics.keys()})
return logs
+ def get_dataset(self) -> xr.Dataset:
+ """Get the coefficients as an xarray Dataset."""
+ target_coefficients, gen_coefficients = self._get_coefficients()
+ if target_coefficients is None or gen_coefficients is None:
+ return xr.Dataset()
+ target_coefficients_ds = xr.Dataset(
+ {
+ name: (
+ ["lat", "lon"],
+ target_coefficients[name].cpu().numpy(),
+ self._get_var_attrs(name),
+ )
+ for name in target_coefficients.keys()
+ }
+ ).expand_dims({"source": ["target"]})
+ gen_coefficients_ds = xr.Dataset(
+ {
+ name: (["lat", "lon"], gen_coefficients[name].cpu().numpy())
+ for name in gen_coefficients.keys()
+ }
+ ).expand_dims({"source": ["prediction"]})
+ return xr.concat([target_coefficients_ds, gen_coefficients_ds], dim="source")
+
+ def _get_var_attrs(self, name: str) -> Dict[str, str]:
+ if name in self._variable_metadata:
+ attrs_name = self._variable_metadata[name].long_name
+ attrs_units = self._variable_metadata[name].units
+ else:
+ attrs_name = name
+ attrs_units = "unknown units"
+ return {
+ "long_name": f"{attrs_name} regression coefficient with Nino3.4 index",
+ "units": f"{attrs_units} / K",
+ }
+
def get_sample_index_series(
index_data: xr.DataArray,
- initial_times: xr.DataArray,
+ initial_time: xr.DataArray,
n_forward_timesteps: int,
timestep: datetime.timedelta,
overlap_threshold: float = OVERLAP_THRESHOLD,
@@ -312,7 +346,7 @@ def get_sample_index_series(
Args:
index_data: ENSO index data with a time coordinate.
- initial_times: Initial times for each sample.
+ initial_time: Initial time for each sample.
n_forward_timesteps: Number of forward timesteps for each sample.
timestep: Timestep duration.
overlap_threshold: Required overlap of reference index with inference period.
@@ -321,38 +355,54 @@ def get_sample_index_series(
List of zero-mean index series for each sample, or None if the sample does
not overlap sufficiently with the reference index.
"""
- data_calendar = initial_times.dt.calendar
+ data_calendar = initial_time.dt.calendar
index_calendar = index_data.time.dt.calendar
if data_calendar != index_calendar:
index_data = index_data.convert_calendar(
calendar=data_calendar, dim="time", use_cftime=True
)
sample_index_series: List[Optional[xr.DataArray]] = []
- for initial_time in initial_times:
+ for initial_time_sample in initial_time:
duration = n_forward_timesteps * timestep
- end_time = initial_time + duration
+ end_time = initial_time_sample + duration
# select index data that overlaps with the inference period, plus a
# half-timestep buffer since we will later reindex with nearest neighbor
index_timestep_seconds = (
index_data.time[1].item() - index_data.time[0].item()
).total_seconds()
half_index_timestep = datetime.timedelta(seconds=index_timestep_seconds / 2)
- index_series = index_data.sel(
+ sample_index_data_selection = index_data.sel(
time=slice(
- initial_time - half_index_timestep,
+ initial_time_sample - half_index_timestep,
end_time + half_index_timestep,
)
)
- if index_series.sizes["time"] == 0:
+ if sample_index_data_selection.sizes["time"] == 0:
+ # no overlap
sample_index_series.append(None)
else:
- index_series_duration = (
- index_series.time[-1].item() - index_series.time[0].item()
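+            # Build the expected time axis for this sample at the model timestep.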
+ sample_time = xr.cftime_range(
+ start=initial_time_sample.item(),
+ end=end_time.item(),
+ freq=f"{int(timestep.total_seconds())}s",
+ calendar=data_calendar,
)
- if index_series_duration > overlap_threshold * duration:
- index_series = index_series - index_series.mean()
- sample_index_series.append(index_series)
+ valid_sample_time = sample_time.where(
+ np.logical_and(
+ sample_time >= sample_index_data_selection.time[0],
+ sample_time <= sample_index_data_selection.time[-1],
+ ),
+ ).dropna()
+ if len(valid_sample_time) > len(sample_time) * overlap_threshold:
+ reindexed_series = sample_index_data_selection.reindex(
+ time=sample_time, method="nearest"
+ )
+ reindexed_series_zero_mean = reindexed_series - reindexed_series.mean(
+ "time"
+ )
+ sample_index_series.append(reindexed_series_zero_mean)
else:
+ # insufficient overlap
sample_index_series.append(None)
return sample_index_series
diff --git a/fme/fme/core/aggregator/inference/enso/index.py b/fme/fme/ace/aggregator/inference/enso/index.py
similarity index 100%
rename from fme/fme/core/aggregator/inference/enso/index.py
rename to fme/fme/ace/aggregator/inference/enso/index.py
diff --git a/fme/fme/core/aggregator/inference/enso/test_enso.py b/fme/fme/ace/aggregator/inference/enso/test_enso.py
similarity index 73%
rename from fme/fme/core/aggregator/inference/enso/test_enso.py
rename to fme/fme/ace/aggregator/inference/enso/test_enso.py
index 1ee1258..8bef935 100644
--- a/fme/fme/core/aggregator/inference/enso/test_enso.py
+++ b/fme/fme/ace/aggregator/inference/enso/test_enso.py
@@ -8,6 +8,7 @@
import xarray as xr
from fme.core.device import get_device
+from fme.core.gridded_ops import LatLonOperations
from .enso import OVERLAP_THRESHOLD, EnsoCoefficientEvaluatorAggregator
@@ -39,14 +40,14 @@ def _get_data(
n_lon: int,
calendar: str = "julian",
):
- times = xr.cftime_range(
+ time = xr.cftime_range(
start="2000-01-01",
periods=n_times,
freq="6h",
calendar=calendar,
)
enso_index = xr.DataArray(
- _data_generator(scale, n_times), dims=["time"], coords={"time": times}
+ _data_generator(scale, n_times), dims=["time"], coords={"time": time}
)
target_data = { # make target data perfectly correlated
"a": torch.tile(
@@ -60,17 +61,17 @@ def _get_data(
[n_samples, 1, n_lat, n_lon],
).to(device=get_device()),
}
- sample_times = xr.concat(
+ sample_time = xr.concat(
[
xr.DataArray(
- times.values,
+ time.values,
dims=["time"],
)
for _ in range(n_samples)
],
dim="sample",
)
- return enso_index, sample_times, target_data, gen_data
+ return enso_index, sample_time, target_data, gen_data
@pytest.mark.parametrize("scaling", [0.5, 1.0, 2.0])
@@ -81,10 +82,11 @@ def test_enso_coefficient_aggregator_values(scaling):
Check that:
- The aggregator maintains a zero-mean subset of the enso index for
each sample.
- - The target and gen coefficients are scaled versions of 1.0 and -1.0 for
- perfectly correlated and anti-correlated data, respectively.
- - The global-mean coefficient RMSE is scaled from 2.0 (perfect
+ - The logged target and gen coefficients are scaled versions of 1.0 and
+        -1.0 for perfectly correlated and anti-correlated data, respectively.
+ - The logged global-mean coefficient RMSE is scaled from 2.0 (perfect
correlation minus perfect anti-correlation).
+ - The above two are also true via the `get_dataset` method.
Args:
scaling: How much to scale up or down the target and generated data;
@@ -95,28 +97,31 @@ def test_enso_coefficient_aggregator_values(scaling):
"""
n_samples, n_times, n_lat, n_lon = 2, 28, 3, 3
scale = 3
- area_weights = torch.ones([n_lat, n_lon])
+ area_weights = torch.ones([n_lat, n_lon], device=get_device())
# get data that doesn't vary in space, but varies in time with the ENSO index
- enso_index, sample_times, target_data, gen_data = _get_data(
+ enso_index, sample_time, target_data, gen_data = _get_data(
scale, n_samples, n_times, n_lat, n_lon
)
with change_aggregator_enso_index(EnsoCoefficientEvaluatorAggregator, enso_index):
enso_agg = EnsoCoefficientEvaluatorAggregator(
- initial_times=sample_times.isel(time=0),
+ initial_time=sample_time.isel(time=0),
n_forward_timesteps=(n_times - 1),
timestep=datetime.timedelta(hours=6),
- area_weights=area_weights,
+ gridded_operations=LatLonOperations(area_weights),
)
assert len(enso_agg._sample_index_series) == n_samples
for index_values in enso_agg._sample_index_series:
+ assert index_values is not None
# check that the index values are zero-mean for each sample
assert np.isclose(index_values.mean().item(), 0.0)
target_data["a"] *= scaling
gen_data["a"] *= scaling
- enso_agg.record_batch(time=sample_times, target_data=target_data, gen_data=gen_data)
- enso_agg.record_batch(time=sample_times, target_data=target_data, gen_data=gen_data)
+ enso_agg.record_batch(time=sample_time, target_data=target_data, gen_data=gen_data)
+ enso_agg.record_batch(time=sample_time, target_data=target_data, gen_data=gen_data)
coefficients = enso_agg._get_coefficients()
target_coefficients, gen_coefficients = coefficients
+ assert target_coefficients is not None
+ assert gen_coefficients is not None
# check that the target coefficients are 1.0 * scaling (perfectly correlated)
assert torch.allclose(target_coefficients["a"], torch.tensor(scaling))
# check that the gen coefficients are -1.0 * scaling (perfectly anti-correlated)
@@ -127,6 +132,16 @@ def test_enso_coefficient_aggregator_values(scaling):
np.testing.assert_almost_equal(
logs["enso_coefficients/rmse/a"], 2.0 * scaling, decimal=5
)
+ enso_dataset = enso_agg.get_dataset()
+ # check that the coefficients are as expected in the dataset also
+ np.testing.assert_array_almost_equal(
+ enso_dataset["a"].sel(source="target").values,
+ target_coefficients["a"].cpu().numpy(),
+ )
+ np.testing.assert_array_almost_equal(
+ enso_dataset["a"].sel(source="prediction").values,
+ gen_coefficients["a"].cpu().numpy(),
+ )
@pytest.mark.parametrize("shift", [1.5, 0.95, 0.05, 0.0])
@@ -138,27 +153,29 @@ def test_enso_index_inference_overlap(shift):
n_samples, n_times, n_lat, n_lon = 2, 28, 3, 3
data_scale = 3
area_weights = torch.ones([n_lat, n_lon])
- enso_index, sample_times, target_data, gen_data = _get_data(
+ enso_index, sample_time, target_data, gen_data = _get_data(
data_scale, n_samples, n_times, n_lat, n_lon
)
# shift the sample times so they only partially overlap the reference index
index_duration = enso_index.time[-1].item() - enso_index.time[0].item()
offset_seconds = shift * index_duration.total_seconds()
- sample_times += datetime.timedelta(seconds=offset_seconds)
+ sample_time += datetime.timedelta(seconds=offset_seconds)
with change_aggregator_enso_index(EnsoCoefficientEvaluatorAggregator, enso_index):
enso_agg = EnsoCoefficientEvaluatorAggregator(
- initial_times=sample_times.isel(time=0),
+ initial_time=sample_time.isel(time=0),
n_forward_timesteps=(n_times - 1),
timestep=datetime.timedelta(hours=6),
- area_weights=area_weights,
+ gridded_operations=LatLonOperations(area_weights),
)
- enso_agg.record_batch(time=sample_times, target_data=target_data, gen_data=gen_data)
+ enso_agg.record_batch(time=sample_time, target_data=target_data, gen_data=gen_data)
target_coefficients, gen_coefficients = enso_agg._get_coefficients()
- overlap = 1.0 - shift
+ overlap = max(1.0 - shift, 0.0)
if overlap < OVERLAP_THRESHOLD: # should be empty dict
assert not target_coefficients
assert not gen_coefficients
else:
+ assert target_coefficients is not None
+ assert gen_coefficients is not None
assert isinstance(target_coefficients["a"], torch.Tensor)
assert isinstance(gen_coefficients["a"], torch.Tensor)
@@ -180,17 +197,17 @@ def test_enso_agg_calendar(calendar):
n_samples, n_times, n_lat, n_lon = 2, 28, 3, 3
data_scale = 3
area_weights = torch.ones([n_lat, n_lon])
- enso_index, sample_times, target_data, gen_data = _get_data(
+ enso_index, sample_time, target_data, gen_data = _get_data(
data_scale, n_samples, n_times, n_lat, n_lon, calendar=calendar
)
enso_index = enso_index.convert_calendar("julian", dim="time", use_cftime=True)
with change_aggregator_enso_index(EnsoCoefficientEvaluatorAggregator, enso_index):
enso_agg = EnsoCoefficientEvaluatorAggregator(
- initial_times=sample_times.isel(time=0),
+ initial_time=sample_time.isel(time=0),
n_forward_timesteps=(n_times - 1),
timestep=datetime.timedelta(hours=6),
- area_weights=area_weights,
+ gridded_operations=LatLonOperations(area_weights),
)
- enso_agg.record_batch(time=sample_times, target_data=target_data, gen_data=gen_data)
+ enso_agg.record_batch(time=sample_time, target_data=target_data, gen_data=gen_data)
target_coefficients, gen_coefficients = enso_agg._get_coefficients()
assert (target_coefficients is not None) and (gen_coefficients is not None)
diff --git a/fme/fme/core/aggregator/inference/histogram.py b/fme/fme/ace/aggregator/inference/histogram.py
similarity index 97%
rename from fme/fme/core/aggregator/inference/histogram.py
rename to fme/fme/ace/aggregator/inference/histogram.py
index c7d7354..5338bb3 100644
--- a/fme/fme/core/aggregator/inference/histogram.py
+++ b/fme/fme/ace/aggregator/inference/histogram.py
@@ -12,7 +12,6 @@ def __init__(self):
@torch.no_grad()
def record_batch(
self,
- loss: float,
target_data: TensorMapping,
gen_data: TensorMapping,
target_data_norm: TensorMapping,
diff --git a/fme/fme/ace/aggregator/inference/main.py b/fme/fme/ace/aggregator/inference/main.py
new file mode 100644
index 0000000..83589f0
--- /dev/null
+++ b/fme/fme/ace/aggregator/inference/main.py
@@ -0,0 +1,719 @@
+import dataclasses
+import datetime
+import warnings
+from typing import (
+ Callable,
+ Dict,
+ Iterable,
+ List,
+ Mapping,
+ Optional,
+ Protocol,
+ Sequence,
+ Union,
+)
+
+import torch
+import xarray as xr
+
+from fme.ace.data_loading.batch_data import BatchData, PairedData, PrognosticState
+from fme.core.coordinates import (
+ HorizontalCoordinates,
+ HybridSigmaPressureCoordinate,
+ LatLonCoordinates,
+)
+from fme.core.dataset.data_typing import VariableMetadata
+from fme.core.generics.aggregator import (
+ InferenceAggregatorABC,
+ InferenceLog,
+ InferenceLogs,
+)
+from fme.core.gridded_ops import GriddedOperations
+from fme.core.typing_ import TensorDict, TensorMapping
+from fme.core.wandb import Table, WandB
+
+from ..one_step.reduced import MeanAggregator as OneStepMeanAggregator
+from .annual import GlobalMeanAnnualAggregator
+from .enso import EnsoCoefficientEvaluatorAggregator
+from .histogram import HistogramAggregator
+from .reduced import MeanAggregator, SingleTargetMeanAggregator
+from .seasonal import SeasonalAggregator
+from .spectrum import PairedSphericalPowerSpectrumAggregator
+from .time_mean import TimeMeanAggregator, TimeMeanEvaluatorAggregator
+from .video import VideoAggregator
+from .zonal_mean import ZonalMeanAggregator
+
+wandb = WandB.get_instance()
+APPROXIMATELY_TWO_YEARS = datetime.timedelta(days=730)
+SLIGHTLY_LESS_THAN_FIVE_YEARS = datetime.timedelta(days=1800)
+
+
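+# The Protocol classes below describe the minimal interface (record_batch,
+# get_logs and get_dataset) expected of the aggregators used in this module.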
+class _Aggregator(Protocol):
+ @torch.no_grad()
+ def record_batch(
+ self,
+ data: TensorMapping,
+ ): ...
+
+ @torch.no_grad()
+ def get_logs(self, label: str): ...
+
+ @torch.no_grad()
+ def get_dataset(self) -> xr.Dataset: ...
+
+
+class _EvaluatorAggregator(Protocol):
+ @torch.no_grad()
+ def record_batch(
+ self,
+ target_data: TensorMapping,
+ gen_data: TensorMapping,
+ target_data_norm: TensorMapping,
+ gen_data_norm: TensorMapping,
+ i_time_start: int = 0,
+ ): ...
+
+ @torch.no_grad()
+ def get_logs(self, label: str): ...
+
+ @torch.no_grad()
+ def get_dataset(self) -> xr.Dataset: ...
+
+
+class _TimeDependentAggregator(Protocol):
+ @torch.no_grad()
+ def record_batch(
+ self,
+ time: xr.DataArray,
+ data: TensorMapping,
+ ): ...
+
+ @torch.no_grad()
+ def get_logs(self, label: str): ...
+
+ @torch.no_grad()
+ def get_dataset(self) -> xr.Dataset: ...
+
+
+class _TimeDependentEvaluatorAggregator(Protocol):
+ @torch.no_grad()
+ def record_batch(
+ self,
+ time: xr.DataArray,
+ target_data: TensorMapping,
+ gen_data: TensorMapping,
+ ): ...
+
+ @torch.no_grad()
+ def get_logs(self, label: str): ...
+
+ @torch.no_grad()
+ def get_dataset(self) -> xr.Dataset: ...
+
+
+@dataclasses.dataclass
+class InferenceEvaluatorAggregatorConfig:
+ """
+ Configuration for inference evaluator aggregator.
+
+ Parameters:
+ log_histograms: Whether to log histograms of the targets and predictions.
+ log_video: Whether to log videos of the state evolution.
+ log_extended_video: Whether to log wandb videos of the predictions with
+ statistical metrics, only done if log_video is True.
+ log_zonal_mean_images: Whether to log zonal-mean images (hovmollers) with a
+ time dimension.
+ log_seasonal_means: Whether to log seasonal mean metrics and images.
+ log_global_mean_time_series: Whether to log global mean time series metrics.
+ log_global_mean_norm_time_series: Whether to log the normalized global mean
+ time series metrics.
+ monthly_reference_data: Path to monthly reference data to compare against.
+ time_mean_reference_data: Path to reference time means to compare against.
+ """
+
+ log_histograms: bool = False
+ log_video: bool = False
+ log_extended_video: bool = False
+ log_zonal_mean_images: bool = True
+ log_seasonal_means: bool = False
+ log_global_mean_time_series: bool = True
+ log_global_mean_norm_time_series: bool = True
+ monthly_reference_data: Optional[str] = None
+ time_mean_reference_data: Optional[str] = None
+
+ def build(
+ self,
+ vertical_coordinate: HybridSigmaPressureCoordinate,
+ horizontal_coordinates: HorizontalCoordinates,
+ timestep: datetime.timedelta,
+ n_timesteps: int,
+ initial_time: xr.DataArray,
+ normalize: Callable[[TensorMapping], TensorDict],
+ record_step_20: bool = False,
+ variable_metadata: Optional[Mapping[str, VariableMetadata]] = None,
+ channel_mean_names: Optional[Sequence[str]] = None,
+ ) -> "InferenceEvaluatorAggregator":
+ if self.monthly_reference_data is None:
+ monthly_reference_data = None
+ else:
+ monthly_reference_data = xr.open_dataset(self.monthly_reference_data)
+ if self.time_mean_reference_data is None:
+ time_mean = None
+ else:
+ time_mean = xr.open_dataset(self.time_mean_reference_data)
+
+ if n_timesteps > 2**15 and self.log_zonal_mean_images:
+ # matplotlib raises an error if image size is too large, and we plot
+ # one pixel per timestep in the zonal mean images.
+ warnings.warn(
+ "Disabling zonal mean images logging due to large number of timesteps"
+ f" (n_timesteps={n_timesteps}). Set log_zonal_mean_images=False or "
+ "decrease n_timesteps to below 2**15 to avoid this warning."
+ )
+ log_zonal_mean_images = False
+ else:
+ log_zonal_mean_images = self.log_zonal_mean_images
+
+ return InferenceEvaluatorAggregator(
+ vertical_coordinate=vertical_coordinate,
+ horizontal_coordinates=horizontal_coordinates,
+ timestep=timestep,
+ n_timesteps=n_timesteps,
+ initial_time=initial_time,
+ log_histograms=self.log_histograms,
+ log_video=self.log_video,
+ enable_extended_videos=self.log_extended_video,
+ log_zonal_mean_images=log_zonal_mean_images,
+ log_seasonal_means=self.log_seasonal_means,
+ log_global_mean_time_series=self.log_global_mean_time_series,
+ log_global_mean_norm_time_series=self.log_global_mean_norm_time_series,
+ monthly_reference_data=monthly_reference_data,
+ time_mean_reference_data=time_mean,
+ record_step_20=record_step_20,
+ variable_metadata=variable_metadata,
+ channel_mean_names=channel_mean_names,
+ normalize=normalize,
+ )
+
+
+class InferenceEvaluatorAggregator(
+ InferenceAggregatorABC[
+ Union[PairedData, PrognosticState],
+ PairedData,
+ ]
+):
+ """
+ Aggregates statistics for inference comparing a generated and target series.
+
+ To use, call `record_batch` on the results of each batch, then call
+ `get_logs` to get a dictionary of statistics when you're done.
+ """
+
+ def __init__(
+ self,
+ vertical_coordinate: HybridSigmaPressureCoordinate,
+ horizontal_coordinates: HorizontalCoordinates,
+ timestep: datetime.timedelta,
+ n_timesteps: int,
+ initial_time: xr.DataArray,
+ normalize: Callable[[TensorMapping], TensorDict],
+ record_step_20: bool = False,
+ log_video: bool = False,
+ enable_extended_videos: bool = False,
+ log_zonal_mean_images: bool = False,
+ log_seasonal_means: bool = False,
+ log_global_mean_time_series: bool = True,
+ log_global_mean_norm_time_series: bool = True,
+ variable_metadata: Optional[Mapping[str, VariableMetadata]] = None,
+ monthly_reference_data: Optional[xr.Dataset] = None,
+ log_histograms: bool = False,
+ time_mean_reference_data: Optional[xr.Dataset] = None,
+ channel_mean_names: Optional[Sequence[str]] = None,
+ ):
+ """
+ Args:
+ vertical_coordinate: Data vertical coordinate.
+ horizontal_coordinates: Data horizontal coordinates
+ timestep: Timestep of the model.
+ n_timesteps: Number of timesteps of inference that will be run.
+ initial_time: Initial time for each sample.
+ normalize: Normalization function to use.
+ record_step_20: Whether to record metrics at the 20th timestep.
+ log_video: Whether to log videos of the state evolution.
+ enable_extended_videos: Whether to log videos of statistical
+ metrics of the state evolution.
+ log_zonal_mean_images: Whether to log zonal-mean images (Hovmöller plots) with a
+ time dimension.
+ log_seasonal_means: Whether to log seasonal means metrics and images.
+ log_global_mean_time_series: Whether to log global mean time series metrics.
+ log_global_mean_norm_time_series: Whether to log the normalized global mean
+ time series metrics.
+ variable_metadata: Mapping of variable names to their metadata that will
+ be used in generating logged image captions.
+ monthly_reference_data: Reference monthly data for computing target stats.
+ log_histograms: Whether to aggregate histograms.
+ time_mean_reference_data: Reference time means for computing bias stats.
+ channel_mean_names: Names over which to compute channel means. If not
+ provided, all available variables will be used.
+ """
+ self._channel_mean_names = channel_mean_names
+ self._aggregators: Dict[str, _EvaluatorAggregator] = {}
+ self._time_dependent_aggregators: Dict[
+ str, _TimeDependentEvaluatorAggregator
+ ] = {}
+ ops = horizontal_coordinates.gridded_operations
+ self._log_time_series = (
+ log_global_mean_time_series or log_global_mean_norm_time_series
+ )
+ if log_global_mean_time_series:
+ self._aggregators["mean"] = MeanAggregator(
+ ops,
+ target="denorm",
+ n_timesteps=n_timesteps,
+ variable_metadata=variable_metadata,
+ )
+ if log_global_mean_norm_time_series:
+ self._aggregators["mean_norm"] = MeanAggregator(
+ ops,
+ target="norm",
+ n_timesteps=n_timesteps,
+ variable_metadata=variable_metadata,
+ )
+ if record_step_20:
+ self._aggregators["mean_step_20"] = OneStepMeanAggregator(
+ ops, target_time=20
+ )
+ if isinstance(horizontal_coordinates, LatLonCoordinates):
+ if log_zonal_mean_images:
+ self._aggregators["zonal_mean"] = ZonalMeanAggregator(
+ n_timesteps=n_timesteps,
+ variable_metadata=variable_metadata,
+ )
+ self._aggregators["spherical_power_spectrum"] = (
+ PairedSphericalPowerSpectrumAggregator(
+ horizontal_coordinates.area_weights.shape[-2],
+ horizontal_coordinates.area_weights.shape[-1],
+ horizontal_coordinates.grid,
+ )
+ )
+ if log_video:
+ self._aggregators["video"] = VideoAggregator(
+ n_timesteps=n_timesteps,
+ enable_extended_videos=enable_extended_videos,
+ variable_metadata=variable_metadata,
+ )
+ self._aggregators["time_mean"] = TimeMeanEvaluatorAggregator(
+ ops,
+ horizontal_dims=horizontal_coordinates.dims,
+ variable_metadata=variable_metadata,
+ reference_means=time_mean_reference_data,
+ )
+ self._aggregators["time_mean_norm"] = TimeMeanEvaluatorAggregator(
+ ops,
+ horizontal_dims=horizontal_coordinates.dims,
+ target="norm",
+ variable_metadata=variable_metadata,
+ )
+ if log_histograms:
+ self._aggregators["histogram"] = HistogramAggregator()
+ if log_seasonal_means:
+ self._time_dependent_aggregators["seasonal"] = SeasonalAggregator(
+ ops=ops,
+ variable_metadata=variable_metadata,
+ )
+ if n_timesteps * timestep > APPROXIMATELY_TWO_YEARS:
+ self._time_dependent_aggregators["annual"] = GlobalMeanAnnualAggregator(
+ ops=ops,
+ timestep=timestep,
+ variable_metadata=variable_metadata,
+ monthly_reference_data=monthly_reference_data,
+ )
+ if n_timesteps * timestep > SLIGHTLY_LESS_THAN_FIVE_YEARS:
+ self._time_dependent_aggregators["enso_coefficient"] = (
+ EnsoCoefficientEvaluatorAggregator(
+ initial_time,
+ n_timesteps - 1,
+ timestep,
+ gridded_operations=ops,
+ variable_metadata=variable_metadata,
+ )
+ )
+ self._summary_aggregators = {
+ name: agg
+ for name, agg in list(self._aggregators.items())
+ + list(self._time_dependent_aggregators.items())
+ if name not in ["mean", "mean_norm"]
+ }
+ self._n_timesteps_seen = 0
+ self._normalize = normalize
+
+ @property
+ def log_time_series(self) -> bool:
+ return self._log_time_series
+
+ @torch.no_grad()
+ def record_batch(
+ self,
+ data: PairedData,
+ ) -> InferenceLogs:
+ if len(data.prediction) == 0:
+ raise ValueError("No prediction values in data")
+ if len(data.target) == 0:
+ raise ValueError("No target values in data")
+ target_data = {k: v for k, v in data.target.items() if k in data.prediction}
+ target_data_norm = self._normalize(target_data)
+ gen_data_norm = self._normalize(data.prediction)
+ for aggregator in self._aggregators.values():
+ aggregator.record_batch(
+ target_data=target_data,
+ gen_data=data.prediction,
+ target_data_norm=target_data_norm,
+ gen_data_norm=gen_data_norm,
+ i_time_start=self._n_timesteps_seen,
+ )
+ for time_dependent_aggregator in self._time_dependent_aggregators.values():
+ time_dependent_aggregator.record_batch(
+ time=data.time,
+ target_data=target_data,
+ gen_data=data.prediction,
+ )
+ n_times = data.time.shape[1]
+ logs = self._get_inference_logs_slice(
+ step_slice=slice(self._n_timesteps_seen, self._n_timesteps_seen + n_times),
+ )
+ self._n_timesteps_seen += n_times
+ return logs
+
+ def record_initial_condition(
+ self,
+ initial_condition: Union[PairedData, PrognosticState],
+ ) -> InferenceLogs:
+ if self._n_timesteps_seen != 0:
+ raise RuntimeError(
+ "record_initial_condition may only be called once, "
+ "before recording any batches"
+ )
+ if isinstance(initial_condition, PairedData):
+ target_data = initial_condition.target
+ target_data_norm = self._normalize(target_data)
+ gen_data = initial_condition.prediction
+ gen_data_norm = self._normalize(gen_data)
+ n_times = initial_condition.time.shape[1]
+ else:
+ batch_data = initial_condition.as_batch_data()
+ target_data = batch_data.data
+ target_data_norm = self._normalize(target_data)
+ gen_data = target_data
+ gen_data_norm = target_data_norm
+ n_times = batch_data.time.shape[1]
+ for aggregator_name in ["mean", "mean_norm"]:
+ aggregator = self._aggregators.get(aggregator_name)
+ if aggregator is not None:
+ aggregator.record_batch(
+ target_data=target_data,
+ gen_data=gen_data,
+ target_data_norm=target_data_norm,
+ gen_data_norm=gen_data_norm,
+ i_time_start=0,
+ )
+ logs = self._get_inference_logs_slice(
+ step_slice=slice(self._n_timesteps_seen, self._n_timesteps_seen + n_times),
+ )
+ self._n_timesteps_seen = n_times
+ return logs
+
+ def get_summary_logs(self) -> InferenceLog:
+ logs = {}
+ for name, aggregator in self._summary_aggregators.items():
+ logs.update(aggregator.get_logs(label=name))
+ return logs
+
+ @torch.no_grad()
+ def _get_logs(self):
+ """
+ Returns logs as can be reported to WandB.
+ """
+ logs = {}
+ for name, aggregator in self._aggregators.items():
+ logs.update(aggregator.get_logs(label=name))
+ for name, time_dependent_aggregator in self._time_dependent_aggregators.items():
+ logs.update(time_dependent_aggregator.get_logs(label=name))
+ return logs
+
+ @torch.no_grad()
+ def _get_inference_logs_slice(self, step_slice: slice):
+ """
+ Returns a subset of the time series for applicable metrics,
+ restricted to a specific slice of steps, as can be reported to WandB.
+
+ Args:
+ step_slice: Timestep slice to determine the time series subset.
+
+ Returns:
+ List of per-step logs for the given slice of steps.
+ """
+ logs = {}
+ for name, aggregator in self._aggregators.items():
+ if isinstance(aggregator, MeanAggregator):
+ logs.update(aggregator.get_logs(label=name, step_slice=step_slice))
+ return to_inference_logs(logs)
+
+ @torch.no_grad()
+ def get_datasets(
+ self, excluded_aggregators: Optional[Iterable[str]] = None
+ ) -> Dict[str, xr.Dataset]:
+ """
+ Returns datasets from combined aggregators.
+
+ Args:
+ excluded_aggregators: aggregator names for which `get_dataset`
+ should not be called and no output should be returned.
+
+ Returns:
+ Dictionary of datasets from aggregators.
+ """
+ if excluded_aggregators is None:
+ excluded_aggregators = []
+
+ combined_aggregators: Dict[
+ str, Union[_Aggregator, _TimeDependentAggregator]
+ ] = {
+ **self._aggregators,
+ **self._time_dependent_aggregators,
+ }
+ return {
+ name: agg.get_dataset()
+ for name, agg in combined_aggregators.items()
+ if name not in excluded_aggregators
+ }
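
For orientation, here is a minimal usage sketch of the evaluator aggregator defined above, following the construction pattern of the unit tests added later in this diff; the shapes, the single variable "a", the 6-hour timestep, and the identity `normalize` are illustrative assumptions rather than values taken from this change.

import datetime

import numpy as np
import torch
import xarray as xr

from fme.ace.aggregator.inference import InferenceEvaluatorAggregator
from fme.ace.data_loading.batch_data import PairedData
from fme.core.coordinates import HybridSigmaPressureCoordinate, LatLonCoordinates
from fme.core.device import get_device

n_sample, n_time, ny, nx, nz = 2, 6, 4, 8, 3
agg = InferenceEvaluatorAggregator(
    vertical_coordinate=HybridSigmaPressureCoordinate(
        torch.arange(nz + 1), torch.arange(nz + 1)
    ),
    horizontal_coordinates=LatLonCoordinates(
        lon=torch.arange(nx),
        lat=torch.arange(ny),
        loaded_lon_name="lon",
        loaded_lat_name="lat",
    ),
    timestep=datetime.timedelta(hours=6),
    n_timesteps=n_time,
    initial_time=xr.DataArray(
        np.zeros((n_sample, 0), dtype="datetime64[ns]"), dims=["sample", "time"]
    ),
    normalize=lambda x: dict(x),  # identity stand-in for a real normalizer
)
window = PairedData(
    prediction={"a": torch.randn(n_sample, n_time, ny, nx, device=get_device())},
    target={"a": torch.randn(n_sample, n_time, ny, nx, device=get_device())},
    time=xr.DataArray(np.zeros((n_sample, n_time)), dims=["sample", "time"]),
)
step_logs = agg.record_batch(window)  # one metrics dict per timestep
summary = agg.get_summary_logs()      # rollout-level metrics, e.g. time-mean maps
datasets = agg.get_datasets()         # one xr.Dataset per aggregator, for saving

In the stepper loop, `record_initial_condition` would be called once before the first window; as the method above shows, it only feeds the `mean` and `mean_norm` aggregators.
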
+
+
+def to_inference_logs(
+ log: Mapping[str, Union[Table, float, int]],
+) -> List[Dict[str, Union[float, int]]]:
+ # we have a dictionary which contains WandB tables
+ # which we will convert to a list of dictionaries, one for each
+ # row in the tables. Any scalar values will be reported in the last
+ # dictionary.
+ n_rows = 0
+ for val in log.values():
+ if isinstance(val, Table):
+ n_rows = max(n_rows, len(val.data))
+ logs: List[Dict[str, Union[float, int]]] = []
+ for i in range(max(1, n_rows)): # need at least one for non-series values
+ logs.append({})
+ for key, val in log.items():
+ if isinstance(val, Table):
+ for i, row in enumerate(val.data):
+ for j, col in enumerate(val.columns):
+ key_without_table_name = key[: key.rfind("/")]
+ logs[i][f"{key_without_table_name}/{col}"] = row[j]
+ else:
+ logs[-1][key] = val
+ return logs
+
+
+def table_to_logs(table: Table) -> List[Dict[str, Union[float, int]]]:
+ """
+ Converts a WandB table into a list of dictionaries.
+ """
+ logs = []
+ for row in table.data:
+ logs.append({table.columns[i]: row[i] for i in range(len(row))})
+ return logs
+
+
+@dataclasses.dataclass
+class InferenceAggregatorConfig:
+ """
+ Configuration for inference aggregator.
+
+ Parameters:
+ time_mean_reference_data: Path to reference time means to compare against.
+ log_global_mean_time_series: Whether to log global mean time series metrics.
+ """
+
+ time_mean_reference_data: Optional[str] = None
+ log_global_mean_time_series: bool = True
+
+ def build(
+ self,
+ gridded_operations: GriddedOperations,
+ n_timesteps: int,
+ variable_metadata: Optional[Mapping[str, VariableMetadata]] = None,
+ ) -> "InferenceAggregator":
+ if self.time_mean_reference_data is not None:
+ time_means = xr.open_dataset(self.time_mean_reference_data)
+ else:
+ time_means = None
+ return InferenceAggregator(
+ gridded_operations=gridded_operations,
+ n_timesteps=n_timesteps,
+ variable_metadata=variable_metadata,
+ time_mean_reference_data=time_means,
+ log_global_mean_time_series=self.log_global_mean_time_series,
+ )
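
A short sketch of how this config might be built into an aggregator; the area-weight tensor and rollout length are placeholders, and `LatLonOperations` is taken from the tests elsewhere in this diff.

import torch

from fme.core.gridded_ops import LatLonOperations

# InferenceAggregatorConfig is defined just above; in practice the gridded
# operations and rollout length come from the inference driver.
config = InferenceAggregatorConfig(log_global_mean_time_series=True)
aggregator = config.build(
    gridded_operations=LatLonOperations(torch.ones(4, 8)),  # uniform area weights
    n_timesteps=40,
)
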
+
+
+class InferenceAggregator(
+ InferenceAggregatorABC[
+ PrognosticState,
+ BatchData,
+ ]
+):
+ """
+ Aggregates statistics on a single timeseries of data.
+
+ To use, call `record_batch` on the results of each batch, then call
+ `get_summary_logs` to get a dictionary of summary statistics when you're done.
+ """
+
+ def __init__(
+ self,
+ gridded_operations: GriddedOperations,
+ n_timesteps: int,
+ variable_metadata: Optional[Mapping[str, VariableMetadata]] = None,
+ time_mean_reference_data: Optional[xr.Dataset] = None,
+ log_global_mean_time_series: bool = True,
+ ):
+ """
+ Args:
+ gridded_operations: Gridded operations for computing horizontal reductions.
+ n_timesteps: Number of timesteps of inference that will be run.
+ variable_metadata: Mapping of variable names to their metadata that will
+ be used in generating logged image captions.
+ time_mean_reference_data: Reference time means for computing bias stats.
+ log_global_mean_time_series: Whether to log global mean time series metrics.
+ """
+ self._log_time_series = log_global_mean_time_series
+ aggregators: Dict[str, _Aggregator] = {}
+ if log_global_mean_time_series:
+ aggregators["mean"] = SingleTargetMeanAggregator(
+ gridded_operations,
+ n_timesteps=n_timesteps,
+ )
+ aggregators["time_mean"] = TimeMeanAggregator(
+ gridded_operations=gridded_operations,
+ variable_metadata=variable_metadata,
+ reference_means=time_mean_reference_data,
+ )
+ self._aggregators = aggregators
+ self._summary_aggregators = {"time_mean": aggregators["time_mean"]}
+ self._n_timesteps_seen = 0
+
+ @property
+ def log_time_series(self) -> bool:
+ return self._log_time_series
+
+ @torch.no_grad()
+ def record_batch(
+ self,
+ data: BatchData,
+ ) -> InferenceLogs:
+ """
+ Record a batch of data.
+
+ Args:
+ data: Batch of data to record.
+ """
+ if len(data.data) == 0:
+ raise ValueError("data is empty")
+ for aggregator in self._aggregators.values():
+ aggregator.record_batch(
+ data=data.data,
+ i_time_start=self._n_timesteps_seen,
+ )
+ n_times = data.time.shape[1]
+ logs = self._get_inference_logs_slice(
+ step_slice=slice(self._n_timesteps_seen, self._n_timesteps_seen + n_times),
+ )
+ self._n_timesteps_seen += n_times
+ return logs
+
+ def record_initial_condition(
+ self,
+ initial_condition: PrognosticState,
+ ) -> InferenceLogs:
+ if self._n_timesteps_seen != 0:
+ raise RuntimeError(
+ "record_initial_condition may only be called once, "
+ "before recording any batches"
+ )
+ batch_data = initial_condition.as_batch_data()
+ if "mean" in self._aggregators:
+ self._aggregators["mean"].record_batch(
+ data=batch_data.data,
+ i_time_start=0,
+ )
+ n_times = batch_data.time.shape[1]
+ logs = self._get_inference_logs_slice(
+ step_slice=slice(self._n_timesteps_seen, self._n_timesteps_seen + n_times),
+ )
+ self._n_timesteps_seen = n_times
+ return logs
+
+ def get_summary_logs(self) -> InferenceLog:
+ logs = {}
+ for name, aggregator in self._summary_aggregators.items():
+ logs.update(aggregator.get_logs(label=name))
+ return logs
+
+ @torch.no_grad()
+ def _get_logs(self):
+ """
+ Returns logs as can be reported to WandB.
+ """
+ logs = {}
+ for name, aggregator in self._aggregators.items():
+ logs.update(aggregator.get_logs(label=name))
+ return logs
+
+ @torch.no_grad()
+ def _get_inference_logs(self) -> List[Dict[str, Union[float, int]]]:
+ """
+ Returns a list of logs to report to WandB.
+
+ This is done because in inference, we use the wandb step
+ as the time step, meaning we need to re-organize the logged data
+ from tables into a list of dictionaries.
+ """
+ return to_inference_logs(self._get_logs())
+
+ @torch.no_grad()
+ def _get_inference_logs_slice(self, step_slice: slice):
+ """
+ Returns a subset of the time series for applicable metrics,
+ restricted to a specific slice of steps, as can be reported to WandB.
+
+ Args:
+ step_slice: Timestep slice to determine the time series subset.
+ """
+ logs = {}
+ for name, aggregator in self._aggregators.items():
+ if isinstance(aggregator, SingleTargetMeanAggregator):
+ logs.update(aggregator.get_logs(label=name, step_slice=step_slice))
+ return to_inference_logs(logs)
+
+ @torch.no_grad()
+ def get_datasets(
+ self, excluded_aggregators: Optional[Iterable[str]] = None
+ ) -> Dict[str, xr.Dataset]:
+ """
+ Returns datasets from combined aggregators.
+
+ Args:
+ excluded_aggregators: aggregator names for which `get_dataset`
+ should not be called and no output should be returned.
+
+ Returns:
+ Dictionary of datasets from aggregators.
+ """
+ if excluded_aggregators is None:
+ excluded_aggregators = []
+
+ return {
+ name: agg.get_dataset()
+ for name, agg in self._aggregators.items()
+ if name not in excluded_aggregators
+ }
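
And a matching sketch for the prediction-only `InferenceAggregator`, mirroring the new `test_inference.py` further down in this diff; the shapes and uniform area weights are illustrative assumptions.

import numpy as np
import torch
import xarray as xr

from fme.ace.aggregator.inference import InferenceAggregator
from fme.ace.data_loading.batch_data import BatchData
from fme.core.device import get_device
from fme.core.gridded_ops import LatLonOperations

n_sample, n_time, ny, nx = 2, 8, 4, 8
agg = InferenceAggregator(
    LatLonOperations(torch.ones(ny, nx).to(get_device())),  # uniform area weights
    n_timesteps=n_time,
)
batch = BatchData(
    data={"a": torch.randn(n_sample, n_time, ny, nx, device=get_device())},
    time=xr.DataArray(
        np.zeros((n_sample, n_time), dtype="datetime64[ns]"), dims=["sample", "time"]
    ),
)
step_logs = agg.record_batch(batch)  # one log dict per timestep in the window
summary = agg.get_summary_logs()     # rollout-level metrics such as "time_mean/gen_map/a"
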
diff --git a/fme/fme/core/aggregator/inference/reduced.py b/fme/fme/ace/aggregator/inference/reduced.py
similarity index 68%
rename from fme/fme/core/aggregator/inference/reduced.py
rename to fme/fme/ace/aggregator/inference/reduced.py
index 6eceda5..61b5680 100644
--- a/fme/fme/core/aggregator/inference/reduced.py
+++ b/fme/fme/ace/aggregator/inference/reduced.py
@@ -1,15 +1,15 @@
import dataclasses
+from collections import defaultdict
from typing import Dict, List, Literal, Mapping, Optional, Protocol
import numpy as np
import torch
import xarray as xr
-from fme.core import metrics
-from fme.core.data_loading.data_typing import VariableMetadata
+from fme.core.dataset.data_typing import VariableMetadata
from fme.core.device import get_device
from fme.core.distributed import Distributed
-from fme.core.metrics import Dimension
+from fme.core.gridded_ops import GriddedOperations
from fme.core.typing_ import TensorMapping
from fme.core.wandb import Table, WandB
@@ -70,10 +70,7 @@ def __call__(
self,
truth: torch.Tensor,
predicted: torch.Tensor,
- weights: Optional[torch.Tensor] = None,
- dim: Dimension = (),
- ) -> torch.Tensor:
- ...
+ ) -> torch.Tensor: ...
class AreaWeightedSingleTargetFunction(Protocol):
@@ -84,10 +81,7 @@ class AreaWeightedSingleTargetFunction(Protocol):
def __call__(
self,
tensor: torch.Tensor,
- weights: Optional[torch.Tensor] = None,
- dim: Dimension = (),
- ) -> torch.Tensor:
- ...
+ ) -> torch.Tensor: ...
def compute_metric_on(
@@ -102,13 +96,11 @@ def compute_metric_on(
def metric_wrapper(
truth: torch.Tensor,
predicted: torch.Tensor,
- weights: Optional[torch.Tensor] = None,
- dim: Dimension = (),
) -> torch.Tensor:
if source == "gen":
- return metric(predicted, weights=weights, dim=dim)
+ return metric(predicted)
elif source == "target":
- return metric(truth, weights=weights, dim=dim)
+ return metric(truth)
return metric_wrapper
@@ -120,12 +112,10 @@ class AreaWeightedReducedMetric:
def __init__(
self,
- area_weights: torch.Tensor,
device: torch.device,
compute_metric: AreaWeightedFunction,
n_timesteps: int,
):
- self._area_weights = area_weights
self._compute_metric = compute_metric
self._total: Optional[torch.Tensor] = None
self._n_batches = torch.zeros(
@@ -143,17 +133,19 @@ def record(self, target: torch.Tensor, gen: torch.Tensor, i_time_start: int):
i_time_start: The index of the first timestep in the batch.
"""
time_dim = 1
- if target.shape[time_dim] >= gen.shape[time_dim]:
- new_value = self._compute_metric(
- target, gen, weights=self._area_weights, dim=(-2, -1)
- ).mean(dim=0)
- if self._total is None:
- self._total = torch.zeros(
- [self._n_timesteps], dtype=new_value.dtype, device=self._device
- )
- time_slice = slice(i_time_start, i_time_start + gen.shape[1])
- self._total[time_slice] += new_value
- self._n_batches[time_slice] += 1
+ if target.shape != gen.shape:
+ raise RuntimeError(
+ "target and gen must have the same shape, got "
+ f"{target.shape} and {gen.shape}"
+ )
+ new_value = self._compute_metric(truth=target, predicted=gen).mean(dim=0)
+ if self._total is None:
+ self._total = torch.zeros(
+ [self._n_timesteps], dtype=new_value.dtype, device=self._device
+ )
+ time_slice = slice(i_time_start, i_time_start + gen.shape[time_dim])
+ self._total[time_slice] += new_value
+ self._n_batches[time_slice] += 1
def get(self) -> torch.Tensor:
"""Returns the mean metric across recorded batches."""
@@ -165,12 +157,12 @@ def get(self) -> torch.Tensor:
class MeanAggregator:
def __init__(
self,
- area_weights: torch.Tensor,
+ gridded_operations: GriddedOperations,
target: Literal["norm", "denorm"],
n_timesteps: int,
- metadata: Optional[Mapping[str, VariableMetadata]] = None,
+ variable_metadata: Optional[Mapping[str, VariableMetadata]] = None,
):
- self._area_weights = area_weights
+ self._gridded_operations = gridded_operations
self._variable_metrics: Optional[Dict[str, Dict[str, MeanMetric]]] = None
self._shape_x = None
self._shape_y = None
@@ -178,89 +170,81 @@ def __init__(
self._n_timesteps = n_timesteps
self._dist = Distributed.get_instance()
- if metadata is None:
- self._metadata: Mapping[str, VariableMetadata] = {}
+ if variable_metadata is None:
+ self._variable_metadata: Mapping[str, VariableMetadata] = {}
else:
- self._metadata = metadata
+ self._variable_metadata = variable_metadata
def _get_variable_metrics(self, gen_data: TensorMapping):
if self._variable_metrics is None:
- self._variable_metrics = {
- "weighted_rmse": {},
- "weighted_mean_gen": {},
- "weighted_mean_target": {},
- "weighted_bias": {},
- "weighted_std_gen": {},
- }
-
- if self._target == "denorm":
- # redundant for the "norm" case
- self._variable_metrics["weighted_grad_mag_percent_diff"] = {}
+ self._variable_metrics = {}
device = get_device()
- for key in gen_data:
- self._variable_metrics["weighted_rmse"][
- key
- ] = AreaWeightedReducedMetric(
- area_weights=self._area_weights,
+ self._variable_metrics["weighted_rmse"] = defaultdict(
+ lambda: AreaWeightedReducedMetric(
device=device,
- compute_metric=metrics.root_mean_squared_error,
+ compute_metric=self._gridded_operations.area_weighted_rmse,
n_timesteps=self._n_timesteps,
)
- if self._target == "denorm":
- self._variable_metrics["weighted_grad_mag_percent_diff"][
- key
- ] = AreaWeightedReducedMetric(
- area_weights=self._area_weights,
+ )
+ if self._target == "denorm":
+ self._variable_metrics["weighted_grad_mag_percent_diff"] = defaultdict(
+ lambda: AreaWeightedReducedMetric(
device=device,
- compute_metric=metrics.gradient_magnitude_percent_diff,
+ compute_metric=self._gridded_operations.area_weighted_gradient_magnitude_percent_diff, # noqa: E501
n_timesteps=self._n_timesteps,
)
- self._variable_metrics["weighted_mean_gen"][
- key
- ] = AreaWeightedReducedMetric(
- area_weights=self._area_weights,
+ )
+ self._variable_metrics["weighted_mean_gen"] = defaultdict(
+ lambda: AreaWeightedReducedMetric(
device=device,
compute_metric=compute_metric_on(
- source="gen", metric=metrics.weighted_mean
+ source="gen",
+ metric=lambda tensor: (
+ self._gridded_operations.area_weighted_mean(tensor)
+ ),
),
n_timesteps=self._n_timesteps,
)
- self._variable_metrics["weighted_mean_target"][
- key
- ] = AreaWeightedReducedMetric(
- area_weights=self._area_weights,
+ )
+ self._variable_metrics["weighted_mean_target"] = defaultdict(
+ lambda: AreaWeightedReducedMetric(
device=device,
compute_metric=compute_metric_on(
- source="target", metric=metrics.weighted_mean
+ source="target",
+ metric=lambda tensor: (
+ self._gridded_operations.area_weighted_mean(tensor)
+ ),
),
n_timesteps=self._n_timesteps,
)
- self._variable_metrics["weighted_bias"][
- key
- ] = AreaWeightedReducedMetric(
- area_weights=self._area_weights,
+ )
+ self._variable_metrics["weighted_bias"] = defaultdict(
+ lambda: AreaWeightedReducedMetric(
device=device,
- compute_metric=metrics.weighted_mean_bias,
+ compute_metric=self._gridded_operations.area_weighted_mean_bias,
n_timesteps=self._n_timesteps,
)
- self._variable_metrics["weighted_std_gen"][
- key
- ] = AreaWeightedReducedMetric(
- area_weights=self._area_weights,
+ )
+ self._variable_metrics["weighted_std_gen"] = defaultdict(
+ lambda: AreaWeightedReducedMetric(
device=device,
compute_metric=compute_metric_on(
- source="gen", metric=metrics.weighted_std
+ source="gen",
+ metric=(
+ lambda tensor: self._gridded_operations.area_weighted_std(
+ tensor
+ )
+ ),
),
n_timesteps=self._n_timesteps,
)
-
+ )
return self._variable_metrics
@torch.no_grad()
def record_batch(
self,
- loss: float,
target_data: TensorMapping,
gen_data: TensorMapping,
target_data_norm: TensorMapping,
@@ -279,7 +263,7 @@ def record_batch(
i_time_start=i_time_start,
)
- def _get_series_data(self) -> List[_SeriesData]:
+ def _get_series_data(self, step_slice: Optional[slice] = None) -> List[_SeriesData]:
"""Converts internally stored variable_metrics to a list."""
if self._variable_metrics is None:
raise ValueError("No batches have been recorded.")
@@ -288,6 +272,8 @@ def _get_series_data(self) -> List[_SeriesData]:
sorted_keys = sorted(list(self._variable_metrics[metric].keys()))
for key in sorted_keys:
arr = self._variable_metrics[metric][key].get().detach()
+ if step_slice is not None:
+ arr = arr[step_slice]
datum = _SeriesData(
metric_name=metric,
var_name=key,
@@ -297,18 +283,21 @@ def _get_series_data(self) -> List[_SeriesData]:
return data
@torch.no_grad()
- def get_logs(self, label: str):
+ def get_logs(self, label: str, step_slice: Optional[slice] = None):
"""
Returns logs as can be reported to WandB.
Args:
label: Label to prepend to all log keys.
+ step_slice: Slice of forecast steps to log.
"""
logs = {}
series_data: Dict[str, np.ndarray] = {
- datum.get_wandb_key(): datum.data for datum in self._get_series_data()
+ datum.get_wandb_key(): datum.data
+ for datum in self._get_series_data(step_slice)
}
- table = data_to_table(series_data)
+ init_step = 0 if step_slice is None else step_slice.start
+ table = data_to_table(series_data, init_step)
logs[f"{label}/series"] = table
return logs
@@ -319,30 +308,39 @@ def get_dataset(self) -> xr.Dataset:
"""
data_vars = {}
for datum in self._get_series_data():
- metadata = self._metadata.get(
+ metadata = self._variable_metadata.get(
datum.var_name, VariableMetadata("unknown_units", datum.var_name)
)
data_vars[datum.get_xarray_key()] = xr.DataArray(
datum.data, dims=["forecast_step"], attrs=metadata._asdict()
)
- n_forecast_steps = len(next(iter(data_vars.values())))
- coords = {"forecast_step": np.arange(n_forecast_steps)}
+ if len(data_vars.values()) > 0:
+ n_forecast_steps = len(next(iter(data_vars.values())))
+ coords = {"forecast_step": np.arange(n_forecast_steps)}
+ else:
+ coords = {"forecast_step": np.arange(0)}
+
return xr.Dataset(data_vars=data_vars, coords=coords)
-def data_to_table(data: Dict[str, np.ndarray]) -> Table:
+def data_to_table(data: Dict[str, np.ndarray], init_step: int = 0) -> Table:
"""
Convert a dictionary of 1-dimensional timeseries data to a wandb Table.
+
+ Args:
+ data: dictionary of timeseries data.
+ init_step: initial step corresponding to the first row's "forecast_step"
"""
keys = sorted(list(data.keys()))
wandb = WandB.get_instance()
table = wandb.Table(columns=["forecast_step"] + keys)
- for i in range(len(data[keys[0]])):
- row = [i]
- for key in keys:
- row.append(data[key][i])
- table.add_data(*row)
+ if len(keys) > 0:
+ for i in range(len(data[keys[0]])):
+ row = [init_step + i]
+ for key in keys:
+ row.append(data[key][i])
+ table.add_data(*row)
return table
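
The `init_step` argument exists so that, when `get_logs` is called once per forecast window with a `step_slice`, the "forecast_step" column carries absolute step indices rather than restarting at zero; a tiny sketch of the arithmetic with made-up slice bounds:

step_slice = slice(6, 9)  # hypothetical window covering absolute steps 6-8
init_step = 0 if step_slice is None else step_slice.start
rows = [init_step + i for i in range(step_slice.stop - step_slice.start)]
assert rows == [6, 7, 8]  # labeled 6, 7, 8 rather than 0, 1, 2
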
@@ -353,12 +351,10 @@ class AreaWeightedSingleTargetReducedMetric:
def __init__(
self,
- area_weights: torch.Tensor,
device: torch.device,
compute_metric: AreaWeightedSingleTargetFunction,
n_timesteps: int,
):
- self._area_weights = area_weights
self._compute_metric = compute_metric
self._total: Optional[torch.Tensor] = None
self._n_batches = torch.zeros(
@@ -374,9 +370,7 @@ def record(self, tensor: torch.Tensor, i_time_start: int):
tensor: batch data. Should have shape [batch, time, height, width].
i_time_start: The index of the first timestep in the batch.
"""
- new_value = self._compute_metric(
- tensor, weights=self._area_weights, dim=(-2, -1)
- ).mean(dim=0)
+ new_value = self._compute_metric(tensor).mean(dim=0)
if self._total is None:
self._total = torch.zeros(
[self._n_timesteps], dtype=new_value.dtype, device=self._device
@@ -395,11 +389,11 @@ def get(self) -> torch.Tensor:
class SingleTargetMeanAggregator:
def __init__(
self,
- area_weights: torch.Tensor,
+ gridded_operations: GriddedOperations,
n_timesteps: int,
- metadata: Optional[Mapping[str, VariableMetadata]] = None,
+ variable_metadata: Optional[Mapping[str, VariableMetadata]] = None,
):
- self._area_weights = area_weights
+ self._ops = gridded_operations
self._variable_metrics: Optional[
Dict[str, Dict[str, SingleTargetMeanMetric]]
] = None
@@ -408,10 +402,10 @@ def __init__(
self._n_timesteps = n_timesteps
self._dist = Distributed.get_instance()
- if metadata is None:
- self._metadata: Mapping[str, VariableMetadata] = {}
+ if variable_metadata is None:
+ self._variable_metadata: Mapping[str, VariableMetadata] = {}
else:
- self._metadata = metadata
+ self._variable_metadata = variable_metadata
def _get_variable_metrics(self, gen_data: TensorMapping):
if self._variable_metrics is None:
@@ -421,23 +415,20 @@ def _get_variable_metrics(self, gen_data: TensorMapping):
}
device = get_device()
- for key in gen_data:
- self._variable_metrics["weighted_mean_gen"][
- key
- ] = AreaWeightedSingleTargetReducedMetric(
- area_weights=self._area_weights,
+ self._variable_metrics["weighted_mean_gen"] = defaultdict(
+ lambda: AreaWeightedSingleTargetReducedMetric(
device=device,
- compute_metric=metrics.weighted_mean,
+ compute_metric=lambda x: self._ops.area_weighted_mean(x),
n_timesteps=self._n_timesteps,
)
- self._variable_metrics["weighted_std_gen"][
- key
- ] = AreaWeightedSingleTargetReducedMetric(
- area_weights=self._area_weights,
+ )
+ self._variable_metrics["weighted_std_gen"] = defaultdict(
+ lambda: AreaWeightedSingleTargetReducedMetric(
device=device,
- compute_metric=metrics.weighted_std,
+ compute_metric=lambda x: self._ops.area_weighted_std(x),
n_timesteps=self._n_timesteps,
)
+ )
return self._variable_metrics
@@ -455,7 +446,7 @@ def record_batch(
i_time_start=i_time_start,
)
- def _get_series_data(self) -> List[_SeriesData]:
+ def _get_series_data(self, step_slice: Optional[slice] = None) -> List[_SeriesData]:
"""Converts internally stored variable_metrics to a list."""
if self._variable_metrics is None:
raise ValueError("No batches have been recorded.")
@@ -464,6 +455,8 @@ def _get_series_data(self) -> List[_SeriesData]:
sorted_keys = sorted(list(self._variable_metrics[metric].keys()))
for key in sorted_keys:
arr = self._variable_metrics[metric][key].get().detach()
+ if step_slice is not None:
+ arr = arr[step_slice]
datum = _SeriesData(
metric_name=metric,
var_name=key,
@@ -473,18 +466,21 @@ def _get_series_data(self) -> List[_SeriesData]:
return data
@torch.no_grad()
- def get_logs(self, label: str):
+ def get_logs(self, label: str, step_slice: Optional[slice] = None):
"""
Returns logs as can be reported to WandB.
Args:
label: Label to prepend to all log keys.
+ step_slice: Slice of forecast steps to log.
"""
logs = {}
series_data: Dict[str, np.ndarray] = {
- datum.get_wandb_key(): datum.data for datum in self._get_series_data()
+ datum.get_wandb_key(): datum.data
+ for datum in self._get_series_data(step_slice)
}
- table = data_to_table(series_data)
+ init_step = 0 if step_slice is None else step_slice.start
+ table = data_to_table(series_data, init_step)
logs[f"{label}/series"] = table
return logs
@@ -495,7 +491,7 @@ def get_dataset(self) -> xr.Dataset:
"""
data_vars = {}
for datum in self._get_series_data():
- metadata = self._metadata.get(
+ metadata = self._variable_metadata.get(
datum.var_name, VariableMetadata("unknown_units", datum.var_name)
)
data_vars[datum.get_xarray_key()] = xr.DataArray(
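
The rewrite above swaps the old per-key registration loop for `defaultdict`s whose factory builds one reduced metric per variable on first access. A generic sketch of that pattern, using a stand-in metric class that is not part of this diff:

from collections import defaultdict


class _CountingMetric:
    """Stand-in for AreaWeightedReducedMetric; counts how often it is updated."""

    def __init__(self) -> None:
        self.n_updates = 0

    def record(self, value: float) -> None:
        self.n_updates += 1


# One metric object per variable, created lazily the first time a variable is seen.
metrics_per_variable = defaultdict(_CountingMetric)
metrics_per_variable["surface_temperature"].record(1.0)
metrics_per_variable["surface_temperature"].record(2.0)
assert metrics_per_variable["surface_temperature"].n_updates == 2
# Membership tests do not trigger the factory, so unseen variables stay absent.
assert "specific_humidity" not in metrics_per_variable
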
diff --git a/fme/fme/core/aggregator/inference/seasonal.py b/fme/fme/ace/aggregator/inference/seasonal.py
similarity index 87%
rename from fme/fme/core/aggregator/inference/seasonal.py
rename to fme/fme/ace/aggregator/inference/seasonal.py
index 6385e8f..522a082 100644
--- a/fme/fme/core/aggregator/inference/seasonal.py
+++ b/fme/fme/ace/aggregator/inference/seasonal.py
@@ -1,25 +1,27 @@
-from typing import Any, Dict, Mapping, Optional, cast
+import logging
+from typing import Any, Dict, Mapping, Optional, Union, cast
import numpy as np
import torch
import xarray as xr
-from fme.core import metrics
-from fme.core.aggregator.plotting import plot_paneled_data
-from fme.core.data_loading.data_typing import VariableMetadata
+from fme.ace.aggregator.plotting import plot_paneled_data
+from fme.core.dataset.data_typing import VariableMetadata
from fme.core.device import get_device
from fme.core.distributed import Distributed
+from fme.core.gridded_ops import GriddedOperations
from fme.core.typing_ import TensorMapping
+from fme.core.wandb import Image
class SeasonalAggregator:
def __init__(
self,
- area_weights: torch.Tensor,
- metadata: Optional[Mapping[str, VariableMetadata]] = None,
+ ops: GriddedOperations,
+ variable_metadata: Optional[Mapping[str, VariableMetadata]] = None,
):
- self.area_weights = area_weights
- self._metadata = metadata
+ self._area_weighted_mean = ops.area_weighted_mean
+ self._variable_metadata = variable_metadata
self._target_dataset: Optional[xr.Dataset] = None
self._gen_dataset: Optional[xr.Dataset] = None
@@ -84,16 +86,16 @@ def get_logs(self, label: str) -> Dict[str, Any]:
target = cast(xr.Dataset, target / target["counts"]) # type: ignore
gen = cast(xr.Dataset, gen / gen["counts"]) # type: ignore
bias = gen - target
- plots = {}
- metric_logs = {}
+ plots: Dict[str, Image] = {}
+ metric_logs: Dict[str, float] = {}
for name in gen.data_vars.keys():
if name == "counts":
continue
- if self._metadata is not None and name in self._metadata:
- long_name = self._metadata[name].long_name
- units = self._metadata[name].units
+ if self._variable_metadata is not None and name in self._variable_metadata:
+ long_name = self._variable_metadata[name].long_name
+ units = self._variable_metadata[name].units
caption_name = f"{long_name} ({units})"
else:
caption_name = name
@@ -147,10 +149,8 @@ def get_logs(self, label: str) -> Dict[str, Any]:
)
plots[f"bias/{name}"] = image_err
- mse_tensor = metrics.weighted_mean(
+ mse_tensor = self._area_weighted_mean(
torch.as_tensor(bias[name].values ** 2),
- weights=self.area_weights.cpu(),
- dim=(-2, -1),
)
for i, season in enumerate(bias[name].season.values):
rmse = float(mse_tensor[i].sqrt().numpy())
@@ -158,19 +158,24 @@ def get_logs(self, label: str) -> Dict[str, Any]:
rmse = float(
# must compute area mean and then mean across seasons
# before sqrt, so we can't use metrics.root_mean_squared_error
- mse_tensor.mean()
- .sqrt()
- .numpy()
+ mse_tensor.mean().sqrt().numpy()
)
metric_logs[f"time-mean-rmse/{name}"] = rmse
if len(label) > 0:
label = label + "/"
- logs = {}
+ logs: Dict[str, Union[Image, float]] = {}
logs.update({f"{label}{name}": plots[name] for name in plots.keys()})
logs.update({f"{label}{name}": val for name, val in metric_logs.items()})
return logs
+ def get_dataset(self) -> xr.Dataset:
+ logging.debug(
+ "get_dataset not implemented for SeasonalAggregator. "
+ "Returning an empty dataset."
+ )
+ return xr.Dataset()
+
ALL_SEASONS = np.asarray(["DJF", "MAM", "JJA", "SON"])
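
A small numerical illustration of the comment above about why the seasonal time-mean RMSE averages the MSE over area and seasons before taking the square root (the per-season values are made up):

import torch

# Per-season area-mean squared errors (illustrative values).
per_season_mse = torch.tensor([1.0, 9.0])
rmse = per_season_mse.mean().sqrt()    # sqrt((1 + 9) / 2) ≈ 2.236
naive = per_season_mse.sqrt().mean()   # (1 + 3) / 2 = 2.0
assert not torch.isclose(rmse, naive)
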
diff --git a/fme/fme/core/aggregator/inference/spectrum.py b/fme/fme/ace/aggregator/inference/spectrum.py
similarity index 98%
rename from fme/fme/core/aggregator/inference/spectrum.py
rename to fme/fme/ace/aggregator/inference/spectrum.py
index 66334ea..00cc77f 100644
--- a/fme/fme/core/aggregator/inference/spectrum.py
+++ b/fme/fme/ace/aggregator/inference/spectrum.py
@@ -1,3 +1,4 @@
+import logging
import warnings
from collections import defaultdict
from typing import Dict, Optional
@@ -52,7 +53,7 @@ def get_mean(self) -> Dict[str, torch.Tensor]:
class PairedSphericalPowerSpectrumAggregator:
- """Record batches and return plots for paired prediction and target data"""
+ """Record batches and return plots for paired prediction and target data."""
def __init__(self, nlat: int, nlon: int, grid: str = "legendre-gauss"):
self._gen_aggregator = SphericalPowerSpectrumAggregator(nlat, nlon, grid)
@@ -61,7 +62,6 @@ def __init__(self, nlat: int, nlon: int, grid: str = "legendre-gauss"):
@torch.no_grad()
def record_batch(
self,
- loss: float,
target_data: TensorMapping,
gen_data: TensorMapping,
target_data_norm: TensorMapping,
@@ -91,7 +91,7 @@ def get_logs(self, label: str) -> Dict[str, plt.Figure]:
@torch.no_grad()
def get_dataset(self) -> xr.Dataset:
- warnings.warn(
+ logging.debug(
"get_dataset not implemented for PairedSphericalPowerSpectrumAggregator. "
"Returning an empty dataset."
)
diff --git a/fme/fme/core/aggregator/inference/test_annual.py b/fme/fme/ace/aggregator/inference/test_annual.py
similarity index 76%
rename from fme/fme/core/aggregator/inference/test_annual.py
rename to fme/fme/ace/aggregator/inference/test_annual.py
index 3a653fe..92a4b49 100644
--- a/fme/fme/core/aggregator/inference/test_annual.py
+++ b/fme/fme/ace/aggregator/inference/test_annual.py
@@ -8,9 +8,12 @@
import xarray as xr
import fme
-from fme.core.aggregator.inference.annual import GlobalMeanAnnualAggregator
+from fme.ace.aggregator.inference.annual import GlobalMeanAnnualAggregator
+from fme.ace.testing import DimSizes, MonthlyReferenceData
+from fme.core.coordinates import DimSize
from fme.core.device import get_device
-from fme.core.testing import DimSizes, MonthlyReferenceData, mock_distributed
+from fme.core.gridded_ops import LatLonOperations
+from fme.core.testing import mock_distributed
TIMESTEP = datetime.timedelta(hours=6)
@@ -23,20 +26,20 @@ def test_annual_aggregator(tmpdir):
n_time = 365 * 4 * 2
area_weights = torch.ones(n_lat, n_lon).to(fme.get_device())
names = ["a"]
+ horizontal = [DimSize("grid_yt", n_lat), DimSize("grid_xt", n_lon)]
monthly_reference_data = MonthlyReferenceData(
path=pathlib.Path(tmpdir),
names=names,
dim_sizes=DimSizes(
n_time=48,
- n_lat=n_lat,
- n_lon=n_lon,
+ horizontal=horizontal,
nz_interface=1,
),
n_ensemble=3,
)
monthly_ds = xr.open_dataset(monthly_reference_data.data_filename)
agg = GlobalMeanAnnualAggregator(
- area_weights=area_weights,
+ ops=LatLonOperations(area_weights),
timestep=TIMESTEP,
monthly_reference_data=monthly_ds,
)
@@ -77,7 +80,7 @@ def test__get_gathered_means(use_mock_distributed):
n_time = 365 * 4 * 2 # two years, approximately
area_weights = torch.ones(n_lat, n_lon).to(fme.get_device())
agg = GlobalMeanAnnualAggregator(
- area_weights=area_weights,
+ ops=LatLonOperations(area_weights),
timestep=TIMESTEP,
)
target_data = {
@@ -103,11 +106,18 @@ def test__get_gathered_means(use_mock_distributed):
if use_mock_distributed:
world_size = 2
with mock_distributed(world_size=world_size):
- target, gen = agg._get_gathered_means()
+ result = agg._get_gathered_means()
+ assert result is not None
+ target, gen = result
+ combined = agg.get_dataset()
else:
world_size = 1
- target, gen = agg._get_gathered_means()
- for dataset in (target, gen):
- assert set(dataset.dims) == {"sample", "year"}
+ result = agg._get_gathered_means()
+ assert result is not None
+ target, gen = result
+ combined = agg.get_dataset()
+ for dataset in (target, gen, combined):
+ assert set(dataset.dims).issuperset({"sample", "year"})
assert list(dataset.year.values) == [2000, 2001, 2002]
assert dataset.sizes["sample"] == n_sample * world_size
+ assert set(combined.coords["source"].values) == set(["target", "prediction"])
diff --git a/fme/fme/core/aggregator/inference/test_distributed.py b/fme/fme/ace/aggregator/inference/test_distributed.py
similarity index 77%
rename from fme/fme/core/aggregator/inference/test_distributed.py
rename to fme/fme/ace/aggregator/inference/test_distributed.py
index d0985a5..fce2e74 100644
--- a/fme/fme/core/aggregator/inference/test_distributed.py
+++ b/fme/fme/ace/aggregator/inference/test_distributed.py
@@ -1,8 +1,9 @@
import torch
-from fme.core.aggregator.inference.reduced import MeanAggregator
-from fme.core.aggregator.inference.time_mean import TimeMeanEvaluatorAggregator
+from fme.ace.aggregator.inference.reduced import MeanAggregator
+from fme.ace.aggregator.inference.time_mean import TimeMeanEvaluatorAggregator
from fme.core.device import get_device
+from fme.core.gridded_ops import LatLonOperations
from fme.core.testing import mock_distributed
@@ -15,9 +16,11 @@ def test_mean_metrics_call_distributed():
with mock_distributed(-1.0):
data_a = torch.ones([2, 3, 4, 4], device=get_device())
area_weights = torch.ones(1).to(get_device())
- agg = MeanAggregator(area_weights, target="denorm", n_timesteps=3)
+ agg = MeanAggregator(
+ LatLonOperations(area_weights), target="denorm", n_timesteps=3
+ )
sample_data = {"a": data_a}
- agg.record_batch(1.0, sample_data, sample_data, sample_data, sample_data)
+ agg.record_batch(sample_data, sample_data, sample_data, sample_data)
logs = agg.get_logs(label="metrics")
table = logs["metrics/series"]
# assert all data past the first column in the WandB table is -1
@@ -33,11 +36,12 @@ def test_time_mean_metrics_call_distributed():
torch.manual_seed(0)
with mock_distributed(0.0) as mock:
area_weights = torch.ones(1).to(get_device())
- agg = TimeMeanEvaluatorAggregator(area_weights)
+ agg = TimeMeanEvaluatorAggregator(
+ LatLonOperations(area_weights), horizontal_dims=["lat", "lon"]
+ )
target_data = {"a": torch.ones([2, 3, 4, 4], device=get_device())}
gen_data = {"a": torch.randn([2, 3, 4, 4], device=get_device())}
agg.record_batch(
- loss=1.0,
target_data=target_data,
gen_data=gen_data,
target_data_norm=target_data,
diff --git a/fme/fme/ace/aggregator/inference/test_evaluator.py b/fme/fme/ace/aggregator/inference/test_evaluator.py
new file mode 100644
index 0000000..de75db4
--- /dev/null
+++ b/fme/fme/ace/aggregator/inference/test_evaluator.py
@@ -0,0 +1,210 @@
+import datetime
+
+import numpy as np
+import pytest
+import torch
+import xarray as xr
+
+from fme.ace.aggregator.inference import InferenceEvaluatorAggregator
+from fme.ace.data_loading.batch_data import BatchData, PairedData
+from fme.core.coordinates import HybridSigmaPressureCoordinate, LatLonCoordinates
+from fme.core.device import get_device
+
+TIMESTEP = datetime.timedelta(hours=6)
+
+
+def get_zero_time(shape, dims):
+ return xr.DataArray(np.zeros(shape, dtype="datetime64[ns]"), dims=dims)
+
+
+def test_logs_labels_exist():
+ n_sample = 10
+ n_time = 22
+ nx = 2
+ ny = 2
+ nz = 3
+ vertical_coordinate = HybridSigmaPressureCoordinate(
+ torch.arange(nz + 1), torch.arange(nz + 1)
+ )
+ horizontal_coordinates = LatLonCoordinates(
+ lon=torch.arange(nx),
+ lat=torch.arange(ny),
+ loaded_lon_name="lon",
+ loaded_lat_name="lat",
+ )
+ initial_time = get_zero_time(shape=[n_sample, 0], dims=["sample", "time"])
+
+ agg = InferenceEvaluatorAggregator(
+ vertical_coordinate=vertical_coordinate,
+ horizontal_coordinates=horizontal_coordinates,
+ timestep=TIMESTEP,
+ n_timesteps=n_time,
+ initial_time=initial_time,
+ record_step_20=True,
+ log_video=True,
+ log_zonal_mean_images=True,
+ normalize=lambda x: dict(x),
+ )
+ time = xr.DataArray(np.zeros((n_sample, n_time)), dims=["sample", "time"])
+
+ logs = agg.record_batch(
+ data=PairedData(
+ prediction={
+ "a": torch.randn(n_sample, n_time, nx, ny, device=get_device())
+ },
+ target={"a": torch.randn(n_sample, n_time, nx, ny, device=get_device())},
+ time=time,
+ ),
+ )
+ assert len(logs) == n_time
+ expected_step_keys = [
+ "mean/forecast_step",
+ "mean/weighted_mean_gen/a",
+ "mean/weighted_mean_target/a",
+ "mean/weighted_rmse/a",
+ "mean/weighted_std_gen/a",
+ "mean/weighted_bias/a",
+ "mean/weighted_grad_mag_percent_diff/a",
+ "mean_norm/forecast_step",
+ "mean_norm/weighted_mean_gen/a",
+ "mean_norm/weighted_mean_target/a",
+ "mean_norm/weighted_rmse/a",
+ "mean_norm/weighted_std_gen/a",
+ "mean_norm/weighted_bias/a",
+ ]
+ for log in logs:
+ for key in expected_step_keys:
+ assert key in log, key
+ assert len(log) == len(expected_step_keys), set(log).difference(
+ expected_step_keys
+ )
+
+ summary_logs = agg.get_summary_logs()
+ expected_keys = [
+ "mean_step_20/loss",
+ "mean_step_20/weighted_rmse/a",
+ "mean_step_20/weighted_bias/a",
+ "mean_step_20/weighted_grad_mag_percent_diff/a",
+ "spherical_power_spectrum/a",
+ "time_mean/rmse/a",
+ "time_mean/bias/a",
+ "time_mean/bias_map/a",
+ "time_mean/gen_map/a",
+ "time_mean_norm/rmse/a",
+ "time_mean_norm/gen_map/a",
+ "time_mean_norm/rmse/channel_mean",
+ "zonal_mean/error/a",
+ "zonal_mean/gen/a",
+ "video/a",
+ ]
+ for key in expected_keys:
+ assert key in summary_logs, key
+ assert len(summary_logs) == len(expected_keys), set(summary_logs).difference(
+ expected_keys
+ )
+
+
+def test_inference_logs_labels_exist():
+ n_sample = 10
+ n_time = 22
+ nx = 2
+ ny = 2
+ nz = 3
+ vertical_coordinate = HybridSigmaPressureCoordinate(
+ torch.arange(nz + 1), torch.arange(nz + 1)
+ )
+ horizontal_coordinates = LatLonCoordinates(
+ lon=torch.arange(nx),
+ lat=torch.arange(ny),
+ loaded_lon_name="lon",
+ loaded_lat_name="lat",
+ )
+ initial_time = get_zero_time(shape=[n_sample, 0], dims=["sample", "time"])
+ agg = InferenceEvaluatorAggregator(
+ vertical_coordinate=vertical_coordinate,
+ horizontal_coordinates=horizontal_coordinates,
+ timestep=TIMESTEP,
+ n_timesteps=n_time,
+ initial_time=initial_time,
+ record_step_20=True,
+ log_video=True,
+ normalize=lambda x: dict(x),
+ )
+ logs = agg.record_batch(
+ data=PairedData(
+ prediction={
+ "a": torch.randn(n_sample, n_time, nx, ny, device=get_device())
+ },
+ target={"a": torch.randn(n_sample, n_time, nx, ny, device=get_device())},
+ time=xr.DataArray(np.zeros((n_sample, n_time)), dims=["sample", "time"]),
+ ),
+ )
+ assert isinstance(logs, list)
+ assert len(logs) == n_time
+ assert "mean/weighted_bias/a" in logs[0]
+ assert "mean/weighted_mean_gen/a" in logs[0]
+ assert "mean/weighted_mean_target/a" in logs[0]
+ assert "mean/weighted_grad_mag_percent_diff/a" in logs[0]
+ assert "mean/weighted_rmse/a" in logs[0]
+ assert "mean_norm/weighted_bias/a" in logs[0]
+ assert "mean_norm/weighted_mean_gen/a" in logs[0]
+ assert "mean_norm/weighted_mean_target/a" in logs[0]
+ assert "mean_norm/weighted_rmse/a" in logs[0]
+ # series/table data should be rolled out, not included as a table
+ assert "mean/series" not in logs[0]
+ assert "mean_norm/series" not in logs[0]
+ assert "reduced/series" not in logs[0]
+ assert "reduced_norm/series" not in logs[0]
+
+
+@pytest.mark.parametrize(
+ "window_len, n_windows",
+ [
+ pytest.param(3, 1, id="single_window"),
+ pytest.param(3, 2, id="two_windows"),
+ ],
+)
+def test_inference_logs_length(window_len: int, n_windows: int):
+ """
+ Test that the inference logs are the correct length when using one or more
+ windows.
+ """
+ nz = 3
+ nx, ny = 4, 4
+ vertical_coordinate = HybridSigmaPressureCoordinate(
+ torch.arange(nz + 1), torch.arange(nz + 1)
+ )
+ horizontal_coordinates = LatLonCoordinates(
+ lon=torch.arange(nx),
+ lat=torch.arange(ny),
+ loaded_lon_name="lon",
+ loaded_lat_name="lat",
+ )
+ initial_time = get_zero_time(shape=[2, 0], dims=["sample", "time"])
+ agg = InferenceEvaluatorAggregator(
+ vertical_coordinate=vertical_coordinate,
+ horizontal_coordinates=horizontal_coordinates,
+ timestep=TIMESTEP,
+ n_timesteps=window_len * n_windows,
+ initial_time=initial_time,
+ normalize=lambda x: dict(x),
+ )
+ target_data = BatchData.new_on_device(
+ data={"a": torch.zeros([2, window_len, ny, nx], device=get_device())},
+ time=xr.DataArray(np.zeros((2, window_len)), dims=["sample", "time"]),
+ )
+ i_start = 0
+ for i in range(n_windows):
+ sample_data = {"a": torch.zeros([2, window_len, ny, nx], device=get_device())}
+ for i in range(window_len):
+ sample_data["a"][..., i, :, :] = float(i_start + i)
+ paired_data = PairedData.new_on_device(
+ prediction=sample_data,
+ target=target_data.data,
+ time=xr.DataArray(np.zeros((2, window_len)), dims=["sample", "time"]),
+ )
+ logs = agg.record_batch(
+ data=paired_data,
+ )
+ assert len(logs) == window_len
+ i_start += window_len
diff --git a/fme/fme/ace/aggregator/inference/test_inference.py b/fme/fme/ace/aggregator/inference/test_inference.py
new file mode 100644
index 0000000..3faf8dd
--- /dev/null
+++ b/fme/fme/ace/aggregator/inference/test_inference.py
@@ -0,0 +1,106 @@
+import datetime
+
+import numpy as np
+import torch
+import xarray as xr
+
+import fme
+from fme.ace.aggregator.inference import InferenceAggregator
+from fme.ace.data_loading.batch_data import BatchData
+from fme.core.device import get_device
+from fme.core.gridded_ops import LatLonOperations
+
+TIMESTEP = datetime.timedelta(hours=6)
+
+
+def get_zero_time(shape, dims):
+ return xr.DataArray(np.zeros(shape, dtype="datetime64[ns]"), dims=dims)
+
+
+def test_logs_labels_exist():
+ n_sample = 10
+ n_time = 22
+ nx = 2
+ ny = 2
+ area_weights = torch.ones(ny).to(fme.get_device())
+ agg = InferenceAggregator(
+ LatLonOperations(area_weights),
+ n_timesteps=n_time,
+ )
+ gen_data = {"a": torch.randn(n_sample, n_time, nx, ny, device=get_device())}
+ time = get_zero_time(shape=[n_sample, n_time], dims=["sample", "time"])
+ logs = agg.record_batch(
+ BatchData(data=gen_data, time=time),
+ )
+ assert len(logs) == n_time
+ expected_step_keys = [
+ "mean/forecast_step",
+ "mean/weighted_mean_gen/a",
+ "mean/weighted_std_gen/a",
+ ]
+ for log in logs:
+ for key in expected_step_keys:
+ assert key in log, key
+ assert len(log) == len(expected_step_keys), set(log).difference(
+ expected_step_keys
+ )
+ summary_logs = agg.get_summary_logs()
+ expected_summary_keys = ["time_mean/gen_map/a"]
+ for key in expected_summary_keys:
+ assert key in summary_logs, key
+ assert len(summary_logs) == len(expected_summary_keys), set(
+ summary_logs
+ ).difference(expected_summary_keys)
+
+
+def test_logs_labels_exist_with_reference_time_means():
+ n_sample = 10
+ n_time = 22
+ nx = 2
+ ny = 2
+ area_weights = torch.ones(ny).to(fme.get_device())
+ reference_time_means = xr.Dataset(
+ {
+ "a": xr.DataArray(
+ np.random.randn(ny, nx),
+ dims=["grid_yt", "grid_xt"],
+ )
+ }
+ )
+ agg = InferenceAggregator(
+ LatLonOperations(area_weights),
+ n_timesteps=n_time,
+ time_mean_reference_data=reference_time_means,
+ )
+ gen_data = {"a": torch.randn(n_sample, n_time, nx, ny, device=get_device())}
+ time = get_zero_time(shape=[n_sample, n_time], dims=["sample", "time"])
+ logs = agg.record_batch(
+ BatchData(
+ data=gen_data,
+ time=time,
+ ),
+ )
+ assert len(logs) == n_time
+ expected_step_keys = [
+ "mean/forecast_step",
+ "mean/weighted_mean_gen/a",
+ "mean/weighted_std_gen/a",
+ ]
+ for log in logs:
+ for key in expected_step_keys:
+ assert key in log, key
+ assert len(log) == len(expected_step_keys), set(log).difference(
+ expected_step_keys
+ )
+ summary_logs = agg.get_summary_logs()
+ expected_summary_keys = [
+ "time_mean/gen_map/a",
+ "time_mean/ref_bias_map/a",
+ "time_mean/ref_bias/a",
+ "time_mean/ref_rmse/a",
+ ]
+ for key in expected_summary_keys:
+ assert key in summary_logs, key
+ assert len(summary_logs) == len(expected_summary_keys), set(
+ summary_logs
+ ).difference(expected_summary_keys)
diff --git a/fme/fme/core/aggregator/inference/test_reduced.py b/fme/fme/ace/aggregator/inference/test_reduced.py
similarity index 56%
rename from fme/fme/core/aggregator/inference/test_reduced.py
rename to fme/fme/ace/aggregator/inference/test_reduced.py
index b67fc6d..83f5e45 100644
--- a/fme/fme/core/aggregator/inference/test_reduced.py
+++ b/fme/fme/ace/aggregator/inference/test_reduced.py
@@ -2,11 +2,12 @@
import torch
import fme
-from fme.core.aggregator.inference.reduced import (
+from fme.ace.aggregator.inference.reduced import (
AreaWeightedReducedMetric,
SingleTargetMeanAggregator,
)
from fme.core.device import get_device
+from fme.core.gridded_ops import LatLonOperations
def test_area_weighted_reduced_metric_includes_later_window_starts():
@@ -19,7 +20,6 @@ def compute_metric(truth, predicted, weights=None, dim=()):
return truth.mean(dim=(2, 3))
metric = AreaWeightedReducedMetric(
- area_weights=torch.ones([4]),
device=get_device(),
compute_metric=compute_metric,
n_timesteps=7,
@@ -49,14 +49,33 @@ def test_single_target_mean_aggregator():
nx = 2
ny = 2
area_weights = torch.ones(ny).to(fme.get_device())
+ torch.manual_seed(0)
agg = SingleTargetMeanAggregator(
- area_weights,
- n_time_per_window * n_window,
+ gridded_operations=LatLonOperations(area_weights),
+ n_timesteps=n_time_per_window * n_window,
)
- data = {"a": torch.randn(n_sample, n_time_per_window, nx, ny, device=get_device())}
+ data_a = torch.randn(n_sample, n_time_per_window * n_window, nx, ny, device=get_device())
for i in range(n_window):
- agg.record_batch(data, i_time_start=i * n_time_per_window)
+ data = {"a": data_a[:, i * n_time_per_window : (i + 1) * n_time_per_window]}
+ agg.record_batch(data=data, i_time_start=i * n_time_per_window)
logs = agg.get_logs(label="test")
assert "test/series" in logs
+ ds = agg.get_dataset()
+ for i in range(1, data_a.shape[1]):
+ raw_variable = data_a[:, i]
+ raw_global_mean = raw_variable.mean().cpu().numpy()
+ raw_global_std = (
+ raw_variable.std(dim=(1, 2), correction=0).mean().cpu().numpy()
+ ) # metrics are mean over batch
+ np.testing.assert_allclose(
+ raw_global_std,
+ ds["weighted_std_gen-a"].isel(forecast_step=i).values.item(),
+ rtol=1e-5,
+ )
+ np.testing.assert_allclose(
+ raw_global_mean,
+ ds["weighted_mean_gen-a"].isel(forecast_step=i).values.item(),
+ rtol=1e-5,
+ )
diff --git a/fme/fme/core/aggregator/inference/test_seasonal.py b/fme/fme/ace/aggregator/inference/test_seasonal.py
similarity index 91%
rename from fme/fme/core/aggregator/inference/test_seasonal.py
rename to fme/fme/ace/aggregator/inference/test_seasonal.py
index 7c9fe77..ed1107e 100644
--- a/fme/fme/core/aggregator/inference/test_seasonal.py
+++ b/fme/fme/ace/aggregator/inference/test_seasonal.py
@@ -6,8 +6,9 @@
import xarray as xr
import fme
-from fme.core.aggregator.inference.seasonal import SeasonalAggregator
+from fme.ace.aggregator.inference.seasonal import SeasonalAggregator
from fme.core.device import get_device
+from fme.core.gridded_ops import LatLonOperations
def get_zero_time(shape, dims):
@@ -23,7 +24,7 @@ def test_seasonal_aggregator():
n_time = int(365 / 10 * 2 / n_time_step + 1) * n_time_step
area_weights = torch.ones(n_lat, n_lon).to(fme.get_device())
agg = SeasonalAggregator(
- area_weights=area_weights,
+ LatLonOperations(area_weights),
)
target_data = {
"a": torch.randn(n_sample, n_time, n_lat, n_lon, device=get_device())
diff --git a/fme/fme/core/aggregator/inference/test_spectrum.py b/fme/fme/ace/aggregator/inference/test_spectrum.py
similarity index 92%
rename from fme/fme/core/aggregator/inference/test_spectrum.py
rename to fme/fme/ace/aggregator/inference/test_spectrum.py
index 75ffc7a..1346c41 100644
--- a/fme/fme/core/aggregator/inference/test_spectrum.py
+++ b/fme/fme/ace/aggregator/inference/test_spectrum.py
@@ -3,7 +3,7 @@
import torch_harmonics
import fme
-from fme.core.aggregator.inference.spectrum import (
+from fme.ace.aggregator.inference.spectrum import (
PairedSphericalPowerSpectrumAggregator,
SphericalPowerSpectrumAggregator,
)
@@ -34,6 +34,6 @@ def test_paired_spherical_power_spectrum_aggregator():
nlon = 16
agg = PairedSphericalPowerSpectrumAggregator(nlat, nlon)
data = {"a": torch.randn(2, 3, nlat, nlon, device=fme.get_device())}
- agg.record_batch(0.0, data, data, None, None)
+ agg.record_batch(data, data, None, None)
result = agg.get_logs("spectrum")
assert isinstance(result["spectrum/a"], plt.Figure)
diff --git a/fme/fme/core/aggregator/inference/test_time_mean.py b/fme/fme/ace/aggregator/inference/test_time_mean.py
similarity index 60%
rename from fme/fme/core/aggregator/inference/test_time_mean.py
rename to fme/fme/ace/aggregator/inference/test_time_mean.py
index e91ef59..b2feef6 100644
--- a/fme/fme/core/aggregator/inference/test_time_mean.py
+++ b/fme/fme/ace/aggregator/inference/test_time_mean.py
@@ -1,17 +1,22 @@
import numpy as np
import torch
-from fme.core.aggregator.inference.time_mean import (
+from fme.ace.aggregator.inference.time_mean import (
TimeMeanAggregator,
TimeMeanEvaluatorAggregator,
)
from fme.core.device import get_device
+from fme.core.gridded_ops import LatLonOperations
def test_rmse_of_time_mean_all_channels():
torch.manual_seed(0)
area_weights = torch.ones(1).to(get_device())
- agg = TimeMeanEvaluatorAggregator(area_weights, target="norm")
+ agg = TimeMeanEvaluatorAggregator(
+ LatLonOperations(area_weights),
+ horizontal_dims=["lat", "lon"],
+ target="norm",
+ )
target_data_norm = {
"a": torch.ones([2, 3, 4, 4], device=get_device()),
"b": torch.ones([2, 3, 4, 4], device=get_device()) * 3,
@@ -21,7 +26,6 @@ def test_rmse_of_time_mean_all_channels():
"b": torch.ones([2, 3, 4, 4], device=get_device()) * 5,
}
agg.record_batch(
- loss=1.0,
target_data=target_data_norm,
gen_data=gen_data_norm,
target_data_norm=target_data_norm,
@@ -33,9 +37,42 @@ def test_rmse_of_time_mean_all_channels():
assert logs["time_mean_norm/rmse/channel_mean"] == 1.5
+def test_custom_channel_mean_names():
+ torch.manual_seed(0)
+ area_weights = torch.ones(1).to(get_device())
+ agg = TimeMeanEvaluatorAggregator(
+ LatLonOperations(area_weights),
+ horizontal_dims=["lat", "lon"],
+ target="norm",
+ channel_mean_names=["a"],
+ )
+ target_data_norm = {
+ "a": torch.ones([2, 3, 4, 4], device=get_device()),
+ "b": torch.ones([2, 3, 4, 4], device=get_device()) * 3,
+ }
+ gen_data_norm = {
+ "a": torch.ones([2, 3, 4, 4], device=get_device()) * 2.0,
+ "b": torch.ones([2, 3, 4, 4], device=get_device()) * 5,
+ }
+ agg.record_batch(
+ target_data=target_data_norm,
+ gen_data=gen_data_norm,
+ target_data_norm=target_data_norm,
+ gen_data_norm=gen_data_norm,
+ )
+ logs = agg.get_logs(label="time_mean_norm")
+ assert logs["time_mean_norm/rmse/a"] == 1.0
+ assert logs["time_mean_norm/rmse/b"] == 2.0
+ assert logs["time_mean_norm/rmse/channel_mean"] == 1.0
+
+
def test_mean_all_channels_not_in_denorm():
area_weights = torch.ones(1).to(get_device())
- agg = TimeMeanEvaluatorAggregator(area_weights, target="denorm")
+ agg = TimeMeanEvaluatorAggregator(
+ LatLonOperations(area_weights),
+ horizontal_dims=["lat", "lon"],
+ target="denorm",
+ )
target_data = {
"a": torch.ones([2, 3, 4, 4], device=get_device()),
"b": torch.ones([2, 3, 4, 4], device=get_device()) * 3,
@@ -45,7 +82,6 @@ def test_mean_all_channels_not_in_denorm():
"b": torch.ones([2, 3, 4, 4], device=get_device()) * 5,
}
agg.record_batch(
- loss=1.0,
target_data=target_data,
gen_data=gen_data,
target_data_norm=target_data,
@@ -60,16 +96,19 @@ def test_mean_all_channels_not_in_denorm():
def test_bias_values():
area_weights = torch.ones(1).to(get_device())
- agg = TimeMeanEvaluatorAggregator(area_weights, target="denorm")
+ agg = TimeMeanEvaluatorAggregator(
+ LatLonOperations(area_weights),
+ horizontal_dims=["lat", "lon"],
+ target="denorm",
+ )
# use constant values so area-weighting doesn't matter
target_data = {
- "a": torch.rand(1) * torch.ones(size=[2, 3, 4, 5], device=get_device()),
+ "a": (torch.rand(1) * torch.ones(size=[2, 3, 4, 5])).to(device=get_device()),
}
gen_data = {
- "a": torch.rand(1) * torch.ones(size=[2, 3, 4, 5], device=get_device()),
+ "a": (torch.rand(1) * torch.ones(size=[2, 3, 4, 5])).to(device=get_device()),
}
agg.record_batch(
- loss=1.0,
target_data=target_data,
gen_data=gen_data,
target_data_norm=target_data,
@@ -93,10 +132,10 @@ def test_bias_values():
def test_aggregator_mean_values():
area_weights = torch.ones(1).to(get_device())
- agg = TimeMeanAggregator(area_weights)
+ agg = TimeMeanAggregator(LatLonOperations(area_weights))
# use constant values so area-weighting doesn't matter
data = {
- "a": torch.rand(1) * torch.ones(size=[2, 3, 4, 5], device=get_device()),
+ "a": (torch.rand(1) * torch.ones(size=[2, 3, 4, 5])).to(device=get_device()),
}
agg.record_batch(
data=data,
diff --git a/fme/fme/core/aggregator/inference/test_video.py b/fme/fme/ace/aggregator/inference/test_video.py
similarity index 96%
rename from fme/fme/core/aggregator/inference/test_video.py
rename to fme/fme/ace/aggregator/inference/test_video.py
index f82eafe..73a8d80 100644
--- a/fme/fme/core/aggregator/inference/test_video.py
+++ b/fme/fme/ace/aggregator/inference/test_video.py
@@ -4,7 +4,7 @@
import pytest
import torch
-from fme.core.aggregator.inference.video import VideoAggregator
+from fme.ace.aggregator.inference.video import VideoAggregator
from fme.core.device import get_device
from fme.core.typing_ import TensorDict
@@ -82,7 +82,7 @@ def test_video_data(offsets: np.ndarray):
i_end = i_start + n_window_in_memory
target_window, gen_window = time_select(i_start, i_end, gen, target)
aggregator.record_batch(
- loss=0, target_data=target_window, gen_data=gen_window, i_time_start=i_start
+ target_data=target_window, gen_data=gen_window, i_time_start=i_start
)
data = aggregator._get_data()
assert data["bias/a"].target is None
@@ -122,7 +122,7 @@ def test_video_data_without_extended_videos(offsets: np.ndarray):
i_end = i_start + n_window_in_memory
target_window, gen_window = time_select(i_start, i_end, gen, target)
aggregator.record_batch(
- loss=0, target_data=target_window, gen_data=gen_window, i_time_start=i_start
+ target_data=target_window, gen_data=gen_window, i_time_start=i_start
)
data = aggregator._get_data()
assert len(data) == 1
@@ -160,7 +160,6 @@ def test_video_data_values_on_random_inputs(n_batches: int):
target_window, gen_window = time_select(i_start, i_end, gen, target)
for nb in range(n_batches): # shouldn't affect results to duplicate batches
aggregator.record_batch(
- loss=0,
target_data=slice_samples(
target_window,
i_start=nb * samples_per_batch,
diff --git a/fme/fme/core/aggregator/inference/test_zonal_mean.py b/fme/fme/ace/aggregator/inference/test_zonal_mean.py
similarity index 79%
rename from fme/fme/core/aggregator/inference/test_zonal_mean.py
rename to fme/fme/ace/aggregator/inference/test_zonal_mean.py
index 713bf84..946c689 100644
--- a/fme/fme/core/aggregator/inference/test_zonal_mean.py
+++ b/fme/fme/ace/aggregator/inference/test_zonal_mean.py
@@ -1,18 +1,18 @@
import torch
+from fme.ace.aggregator.inference.zonal_mean import ZonalMeanAggregator
from fme.core import get_device
-from fme.core.aggregator.inference.zonal_mean import ZonalMeanAggregator
n_sample, n_time, ny, nx = 3, 6, 10, 20
-loss = 1.0
def test_zonal_mean_dims():
agg = ZonalMeanAggregator(n_timesteps=n_time)
target_data = {"a": torch.randn(n_sample, n_time, ny, nx, device=get_device())}
gen_data = {"a": torch.randn(n_sample, n_time, ny, nx, device=get_device())}
- agg.record_batch(loss, target_data, gen_data, target_data, gen_data, i_time_start=0)
+ agg.record_batch(target_data, gen_data, target_data, gen_data, i_time_start=0)
for data in (agg._target_data, agg._gen_data):
+ assert data is not None
assert data["a"].size() == (
n_sample,
n_time,
@@ -24,10 +24,9 @@ def test_zonal_mean_lat_varying():
agg = ZonalMeanAggregator(n_timesteps=n_time)
arr = torch.arange(ny, dtype=torch.float32, device=get_device())
arr = arr[None, None, :, None].expand(n_sample, n_time, -1, nx)
- agg.record_batch(
- loss, {"a": arr}, {"a": arr}, {"a": arr}, {"a": arr}, i_time_start=0
- )
+ agg.record_batch({"a": arr}, {"a": arr}, {"a": arr}, {"a": arr}, i_time_start=0)
for data in (agg._target_data, agg._gen_data):
+ assert data is not None
torch.testing.assert_close(
data["a"][0, 0, :], # one time row of the zonal mean
torch.arange(ny, dtype=torch.float32, device=get_device()),
@@ -38,10 +37,9 @@ def test_zonal_mean_zonally_varying():
agg = ZonalMeanAggregator(n_timesteps=n_time)
arr = torch.arange(nx, dtype=torch.float32, device=get_device())
arr = arr[None, None, None, :].expand(n_sample, n_time, ny, -1)
- agg.record_batch(
- loss, {"a": arr}, {"a": arr}, {"a": arr}, {"a": arr}, i_time_start=0
- )
+ agg.record_batch({"a": arr}, {"a": arr}, {"a": arr}, {"a": arr}, i_time_start=0)
for data in (agg._target_data, agg._gen_data):
+ assert data is not None
torch.testing.assert_close(
data["a"][0, 0, :], # one time row of the zonal mean
arr.mean() * torch.ones(ny, dtype=torch.float32, device=get_device()),
@@ -53,13 +51,12 @@ def test_zonal_mean_batch_varying():
for i in range(n_sample): # assume one sample per batch
arr = torch.tensor(i, dtype=torch.float32, device=get_device())
arr = arr[None, None, None, None].expand(-1, n_time, ny, nx)
- agg.record_batch(
- loss, {"a": arr}, {"a": arr}, {"a": arr}, {"a": arr}, i_time_start=0
- )
+ agg.record_batch({"a": arr}, {"a": arr}, {"a": arr}, {"a": arr}, i_time_start=0)
for data in (agg._target_data, agg._gen_data):
+ assert data is not None
torch.testing.assert_close(
data["a"].sum(dim=0)[0, 0], # sum over batches, then pick a time/lat point
- torch.arange(n_sample, dtype=torch.float32, device=get_device()).sum()
+ torch.arange(n_sample, dtype=torch.float32, device=get_device()).sum(),
# should be same as sum over batches
)
@@ -72,9 +69,10 @@ def test_zonal_mean_mulitple_time_slices():
arr = torch.arange(ny, dtype=torch.float32, device=get_device())
arr = arr[None, None, :, None].expand(n_sample, n_time_in_memory, ny, nx)
agg.record_batch(
- loss, {"a": arr}, {"a": arr}, {"a": arr}, {"a": arr}, i_time_start=i_time
+ {"a": arr}, {"a": arr}, {"a": arr}, {"a": arr}, i_time_start=i_time
)
for data in (agg._target_data, agg._gen_data):
+ assert data is not None
torch.testing.assert_close(
(data["a"] / agg._n_batches)[0, 0, :],
torch.arange(ny, dtype=torch.float32, device=get_device()),
diff --git a/fme/fme/core/aggregator/inference/time_mean.py b/fme/fme/ace/aggregator/inference/time_mean.py
similarity index 74%
rename from fme/fme/core/aggregator/inference/time_mean.py
rename to fme/fme/ace/aggregator/inference/time_mean.py
index ce1335b..01ba79f 100644
--- a/fme/fme/core/aggregator/inference/time_mean.py
+++ b/fme/fme/ace/aggregator/inference/time_mean.py
@@ -5,9 +5,9 @@
import torch
import xarray as xr
-from fme.core import metrics
-from fme.core.data_loading.data_typing import VariableMetadata
+from fme.core.dataset.data_typing import VariableMetadata
from fme.core.distributed import Distributed
+from fme.core.gridded_ops import GriddedOperations
from fme.core.typing_ import TensorDict, TensorMapping
from fme.core.wandb import Image, WandB
@@ -19,26 +19,27 @@ class _TargetGenPair:
name: str
target: torch.Tensor
gen: torch.Tensor
+ ops: GriddedOperations
def bias(self):
return self.gen - self.target
- def rmse(self, weights: torch.Tensor) -> float:
+ def rmse(self) -> float:
ret = float(
- metrics.root_mean_squared_error(
+ self.ops.area_weighted_rmse(
predicted=self.gen,
truth=self.target,
- weights=weights,
)
.cpu()
.numpy()
)
return ret
- def weighted_mean_bias(self, weights: torch.Tensor) -> float:
+ def weighted_mean_bias(self) -> float:
return float(
- metrics.weighted_mean_bias(
- predicted=self.gen, truth=self.target, weights=weights
+ self.ops.area_weighted_mean_bias(
+ predicted=self.gen,
+ truth=self.target,
)
.cpu()
.numpy()
@@ -58,27 +59,27 @@ class TimeMeanAggregator:
def __init__(
self,
- area_weights: torch.Tensor,
+ gridded_operations: GriddedOperations,
target: Literal["norm", "denorm"] = "denorm",
- metadata: Optional[Mapping[str, VariableMetadata]] = None,
+ variable_metadata: Optional[Mapping[str, VariableMetadata]] = None,
reference_means: Optional[xr.Dataset] = None,
):
"""
Args:
- area_weights: Area weights for each grid cell.
+ gridded_operations: Computes area-weighted operations (e.g. RMSE, mean bias) on the grid.
target: Whether to compute metrics on the normalized or denormalized data,
defaults to "denorm".
- metadata: Mapping of variable names their metadata that will
+ variable_metadata: Mapping of variable names to their metadata that will be
used in generating logged image captions.
reference_means: Dataset containing reference time-mean values
for bias computation.
"""
- self._area_weights = area_weights
+ self._ops = gridded_operations
self._target = target
- if metadata is None:
- self._metadata: Mapping[str, VariableMetadata] = {}
+ if variable_metadata is None:
+ self._variable_metadata: Mapping[str, VariableMetadata] = {}
else:
- self._metadata = metadata
+ self._variable_metadata = variable_metadata
# Dictionaries of tensors of shape [n_lat, n_lon] representing time means
self._data: Optional[TensorDict] = None
self._n_timesteps = 0
@@ -166,6 +167,7 @@ def get_logs(self, label: str) -> Dict[str, Union[float, Image]]:
self._reference_means[name].values, device=pred.device
),
gen=pred,
+ ops=self._ops,
)
bias_map = pair.bias().cpu().numpy()
vmin_bias, vmax_bias = get_cmap_limits(bias_map, diverging=True)
@@ -176,10 +178,8 @@ def get_logs(self, label: str) -> Dict[str, Union[float, Image]]:
bias_fig,
caption=self._get_caption("bias_map", name, vmin_bias, vmax_bias),
)
- logs.update(
- {f"ref_bias/{name}": pair.weighted_mean_bias(self._area_weights)}
- )
- logs.update({f"ref_rmse/{name}": pair.rmse(self._area_weights)})
+ logs.update({f"ref_bias/{name}": pair.weighted_mean_bias()})
+ logs.update({f"ref_rmse/{name}": pair.rmse()})
logs.update({f"ref_bias_map/{name}": bias_image})
if len(label) != 0:
@@ -187,9 +187,9 @@ def get_logs(self, label: str) -> Dict[str, Union[float, Image]]:
return logs
def _get_caption(self, key: str, name: str, vmin: float, vmax: float) -> str:
- if name in self._metadata:
- caption_name = self._metadata[name].long_name
- units = self._metadata[name].units
+ if name in self._variable_metadata:
+ caption_name = self._variable_metadata[name].long_name
+ units = self._variable_metadata[name].units
else:
caption_name, units = name, "unknown_units"
caption = self._image_captions[key].format(name=caption_name, units=units)
@@ -200,9 +200,9 @@ def get_dataset(self) -> xr.Dataset:
dims = ("lat", "lon")
data = {}
for name, pred in self.get_data().items():
- if name in self._metadata:
- long_name = self._metadata[name].long_name
- units = self._metadata[name].units
+ if name in self._variable_metadata:
+ long_name = self._variable_metadata[name].long_name
+ units = self._variable_metadata[name].units
else:
long_name = name
units = "unknown_units"
@@ -233,43 +233,49 @@ class TimeMeanEvaluatorAggregator:
def __init__(
self,
- area_weights: torch.Tensor,
+ ops: GriddedOperations,
+ horizontal_dims: List[str],
target: Literal["norm", "denorm"] = "denorm",
- metadata: Optional[Mapping[str, VariableMetadata]] = None,
+ variable_metadata: Optional[Mapping[str, VariableMetadata]] = None,
reference_means: Optional[xr.Dataset] = None,
+ channel_mean_names: Optional[List[str]] = None,
):
"""
Args:
- area_weights: Area weights for each grid cell.
+ ops: Computes area-weighted operations (e.g. RMSE, mean bias) on the grid.
+ horizontal_dims: Names of the horizontal dimensions.
target: Whether to compute metrics on the normalized or denormalized data,
defaults to "denorm".
- metadata: Mapping of variable names their metadata that will
+ variable_metadata: Mapping of variable names to their metadata that will be
used in generating logged image captions.
reference_means: Dataset containing reference time-mean values
for bias computation.
+ channel_mean_names: Names of variables whose RMSE will be averaged. If
+ not provided, all available variables will be used.
"""
- self._area_weights = area_weights
+ self._ops = ops
+ self._horizontal_dims = horizontal_dims
self._target = target
self._dist = Distributed.get_instance()
- if metadata is None:
- self._metadata: Mapping[str, VariableMetadata] = {}
+ if variable_metadata is None:
+ self._variable_metadata: Mapping[str, VariableMetadata] = {}
else:
- self._metadata = metadata
+ self._variable_metadata = variable_metadata
# Dictionaries of tensors of shape [n_lat, n_lon] representing time means
self._target_agg = TimeMeanAggregator(
- area_weights=area_weights, target=target, metadata=metadata
+ gridded_operations=ops, target=target, variable_metadata=variable_metadata
)
self._gen_agg = TimeMeanAggregator(
- area_weights=area_weights,
+ gridded_operations=ops,
target=target,
- metadata=metadata,
+ variable_metadata=variable_metadata,
reference_means=reference_means,
)
+ self._channel_mean_names = channel_mean_names
@torch.no_grad()
def record_batch(
self,
- loss: float,
target_data: TensorMapping,
gen_data: TensorMapping,
target_data_norm: TensorMapping,
@@ -289,7 +295,12 @@ def _get_target_gen_pairs(self) -> List[_TargetGenPair]:
ret = []
for name in gen_data.keys():
ret.append(
- _TargetGenPair(gen=gen_data[name], target=target_data[name], name=name)
+ _TargetGenPair(
+ gen=gen_data[name],
+ target=target_data[name],
+ name=name,
+ ops=self._ops,
+ )
)
return ret
@@ -313,33 +324,33 @@ def get_logs(self, label: str) -> Dict[str, Union[float, torch.Tensor, Image]]:
),
)
plt.close("all")
- rmse_all_channels[pred.name] = pred.rmse(weights=self._area_weights)
+ rmse_all_channels[pred.name] = pred.rmse()
logs.update({f"rmse/{pred.name}": rmse_all_channels[pred.name]})
if self._target == "denorm":
logs.update(
{
f"{bias_map_key}/{pred.name}": bias_image,
- f"bias/{pred.name}": pred.weighted_mean_bias(
- weights=self._area_weights
- ),
+ f"bias/{pred.name}": pred.weighted_mean_bias(),
}
)
if self._target == "norm":
- logs.update(
- {
- f"rmse/channel_mean": sum(rmse_all_channels.values())
- / len(rmse_all_channels),
- }
- )
+ metric_name = "rmse/channel_mean"
+ if self._channel_mean_names is None:
+ values_to_average = list(rmse_all_channels.values())
+ else:
+ values_to_average = [
+ rmse_all_channels[name] for name in self._channel_mean_names
+ ]
+ logs.update({metric_name: sum(values_to_average) / len(values_to_average)})
if len(label) != 0:
return {f"{label}/{key}": logs[key] for key in logs}
return logs
def _get_caption(self, key: str, name: str, vmin: float, vmax: float) -> str:
- if name in self._metadata:
- caption_name = self._metadata[name].long_name
- units = self._metadata[name].units
+ if name in self._variable_metadata:
+ caption_name = self._variable_metadata[name].long_name
+ units = self._variable_metadata[name].units
else:
caption_name, units = name, "unknown_units"
caption = self._image_captions[key].format(name=caption_name, units=units)
@@ -349,27 +360,27 @@ def _get_caption(self, key: str, name: str, vmin: float, vmax: float) -> str:
def get_dataset(self) -> xr.Dataset:
data = {}
preds = self._get_target_gen_pairs()
- dims = ("lat", "lon")
for pred in preds:
- if pred.name in self._metadata:
- long_name = self._metadata[pred.name].long_name
- units = self._metadata[pred.name].units
+ if pred.name in self._variable_metadata:
+ long_name = self._variable_metadata[pred.name].long_name
+ units = self._variable_metadata[pred.name].units
else:
long_name = pred.name
units = "unknown_units"
gen_metadata = VariableMetadata(long_name=long_name, units=units)._asdict()
- bias_metadata = self._metadata.get(
+ bias_metadata = self._variable_metadata.get(
pred.name, VariableMetadata(long_name=long_name, units=units)
)._asdict()
- gen_metadata = VariableMetadata(long_name=long_name, units=units)._asdict()
data.update(
{
f"bias_map-{pred.name}": xr.DataArray(
- pred.bias().cpu(), dims=dims, attrs=bias_metadata
+ pred.bias().cpu(),
+ dims=self._horizontal_dims,
+ attrs=bias_metadata,
),
f"gen_map-{pred.name}": xr.DataArray(
pred.gen.cpu(),
- dims=dims,
+ dims=self._horizontal_dims,
attrs=gen_metadata,
),
}
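For orientation, a minimal usage sketch of the relocated TimeMeanEvaluatorAggregator, mirroring the updated tests above (tensor shapes and values are illustrative): area weights are now wrapped in a GriddedOperations object, horizontal_dims is passed explicitly, channel_mean_names limits which variables enter rmse/channel_mean, and record_batch no longer takes a loss argument.

import torch
from fme.ace.aggregator.inference.time_mean import TimeMeanEvaluatorAggregator
from fme.core.device import get_device
from fme.core.gridded_ops import LatLonOperations

# area weights are wrapped in a GriddedOperations implementation
agg = TimeMeanEvaluatorAggregator(
    LatLonOperations(torch.ones(1).to(get_device())),
    horizontal_dims=["lat", "lon"],
    target="norm",
    channel_mean_names=["a"],  # only "a" contributes to rmse/channel_mean
)
data = {"a": torch.ones(2, 3, 4, 4, device=get_device())}
agg.record_batch(  # no loss argument anymore
    target_data=data,
    gen_data=data,
    target_data_norm=data,
    gen_data_norm=data,
)
logs = agg.get_logs(label="time_mean_norm")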
diff --git a/fme/fme/core/aggregator/inference/video.py b/fme/fme/ace/aggregator/inference/video.py
similarity index 95%
rename from fme/fme/core/aggregator/inference/video.py
rename to fme/fme/ace/aggregator/inference/video.py
index 0bb17e8..24090a8 100644
--- a/fme/fme/core/aggregator/inference/video.py
+++ b/fme/fme/ace/aggregator/inference/video.py
@@ -5,7 +5,7 @@
import torch
import xarray as xr
-from fme.core.data_loading.data_typing import VariableMetadata
+from fme.core.dataset.data_typing import VariableMetadata
from fme.core.distributed import Distributed
from fme.core.typing_ import TensorDict, TensorMapping
from fme.core.wandb import WandB
@@ -216,9 +216,7 @@ def record_batch(
time_slice = slice(i_time_start, i_time_start + window_steps)
for name, tensor in target_data.items():
self._target_means[name][time_slice, ...] += tensor.mean(dim=0).cpu()
- self._target_squares[name][time_slice, ...] += (
- (tensor**2).mean(dim=0).cpu()
- )
+ self._target_squares[name][time_slice, ...] += (tensor**2).mean(dim=0).cpu()
for name, tensor in gen_data.items():
self._gen_means[name][time_slice, ...] += tensor.mean(dim=0).cpu()
self._gen_squares[name][time_slice, ...] += (tensor**2).mean(dim=0).cpu()
@@ -292,20 +290,20 @@ def __init__(
self,
n_timesteps: int,
enable_extended_videos: bool,
- metadata: Optional[Mapping[str, VariableMetadata]] = None,
+ variable_metadata: Optional[Mapping[str, VariableMetadata]] = None,
):
"""
Args:
n_timesteps: Number of timesteps of inference that will be run.
enable_extended_videos: Whether to log videos of statistical
metrics of state evolution
- metadata: Mapping of variable names their metadata that will
+ variable_metadata: Mapping of variable names to their metadata that will be
used in generating logged video captions.
"""
- if metadata is None:
- self._metadata: Mapping[str, VariableMetadata] = {}
+ if variable_metadata is None:
+ self._variable_metadata: Mapping[str, VariableMetadata] = {}
else:
- self._metadata = metadata
+ self._variable_metadata = variable_metadata
self._mean_data = _MeanVideoData(n_timesteps=n_timesteps)
if enable_extended_videos:
self._error_data: Optional[_ErrorVideoData] = _ErrorVideoData(
@@ -323,7 +321,6 @@ def __init__(
@torch.no_grad()
def record_batch(
self,
- loss: float,
target_data: TensorMapping,
gen_data: TensorMapping,
target_data_norm: Optional[TensorMapping] = None,
@@ -375,14 +372,14 @@ def _get_data(self) -> Mapping[str, _MaybePairedVideoData]:
video_data = {}
def get_units(name: str) -> Optional[str]:
- if name in self._metadata:
- return self._metadata[name].units
+ if name in self._variable_metadata:
+ return self._variable_metadata[name].units
else:
return None
def get_long_name(name: str) -> Optional[str]:
- if name in self._metadata:
- return self._metadata[name].long_name
+ if name in self._variable_metadata:
+ return self._variable_metadata[name].long_name
else:
return None
@@ -480,9 +477,9 @@ def _get_caption(self, name: str) -> str:
caption = (
"Autoregressive (left) prediction and (right) target for {name} [{units}]"
)
- if name in self._metadata:
- caption_name = self._metadata[name].long_name
- units = self._metadata[name].units
+ if name in self._variable_metadata:
+ caption_name = self._variable_metadata[name].long_name
+ units = self._variable_metadata[name].units
else:
caption_name, units = name, "unknown units"
return caption.format(name=caption_name, units=units)
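A brief sketch of the corresponding call-site change for VideoAggregator (tensor shapes are illustrative): the constructor keyword is now variable_metadata, and record_batch no longer receives a loss.

import torch
from fme.ace.aggregator.inference.video import VideoAggregator
from fme.core.device import get_device

agg = VideoAggregator(n_timesteps=3, enable_extended_videos=False)
data = {"a": torch.randn(2, 3, 4, 8, device=get_device())}
# record the full 3-step window starting at time index 0
agg.record_batch(target_data=data, gen_data=data, i_time_start=0)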
diff --git a/fme/fme/core/aggregator/inference/zonal_mean.py b/fme/fme/ace/aggregator/inference/zonal_mean.py
similarity index 86%
rename from fme/fme/core/aggregator/inference/zonal_mean.py
rename to fme/fme/ace/aggregator/inference/zonal_mean.py
index 87ae6a8..5c6de1a 100644
--- a/fme/fme/core/aggregator/inference/zonal_mean.py
+++ b/fme/fme/ace/aggregator/inference/zonal_mean.py
@@ -6,7 +6,7 @@
import torch
import xarray as xr
-from fme.core.data_loading.data_typing import VariableMetadata
+from fme.core.dataset.data_typing import VariableMetadata
from fme.core.device import get_device
from fme.core.distributed import Distributed
from fme.core.typing_ import TensorDict, TensorMapping
@@ -61,32 +61,29 @@ class ZonalMeanAggregator:
def __init__(
self,
n_timesteps: int,
- metadata: Optional[Mapping[str, VariableMetadata]] = None,
+ variable_metadata: Optional[Mapping[str, VariableMetadata]] = None,
):
"""
Args:
n_timesteps: Number of timesteps of inference that will be run.
- metadata: Mapping of variable names their metadata that will
+ variable_metadata: Mapping of variable names to their metadata that will be
used in generating logged image captions.
"""
self._n_timesteps = n_timesteps
self._dist = Distributed.get_instance()
- if metadata is None:
- self._metadata: Mapping[str, VariableMetadata] = {}
+ if variable_metadata is None:
+ self._variable_metadata: Mapping[str, VariableMetadata] = {}
else:
- self._metadata = metadata
+ self._variable_metadata = variable_metadata
self._target_data: Optional[TensorDict] = None
self._gen_data: Optional[TensorDict] = None
self._n_batches = torch.zeros(
n_timesteps, dtype=torch.int32, device=get_device()
- )[
- None, :, None
- ] # sample, time, lat
+ )[None, :, None] # sample, time, lat
def record_batch(
self,
- loss: float,
target_data: TensorMapping,
gen_data: TensorMapping,
target_data_norm: TensorMapping,
@@ -107,9 +104,11 @@ def record_batch(
time_slice = slice(i_time_start, i_time_start + window_steps)
# we can average along longitude without area weighting
for name, tensor in target_data.items():
- self._target_data[name][:, time_slice, :] += tensor.mean(dim=lon_dim)
+ if name in self._target_data:
+ self._target_data[name][:, time_slice, :] += tensor.mean(dim=lon_dim)
for name, tensor in gen_data.items():
- self._gen_data[name][:, time_slice, :] += tensor.mean(dim=lon_dim)
+ if name in self._gen_data:
+ self._gen_data[name][:, time_slice, :] += tensor.mean(dim=lon_dim)
self._n_batches[:, time_slice, :] += 1
def _get_data(self) -> Dict[str, _RawData]:
@@ -134,7 +133,9 @@ def _get_data(self) -> Dict[str, _RawData]:
.numpy()
)
- metadata = self._metadata.get(name, VariableMetadata("unknown_units", name))
+ metadata = self._variable_metadata.get(
+ name, VariableMetadata("unknown_units", name)
+ )
vmin, vmax = get_cmap_limits(gen)
data[f"gen/{name}"] = _RawData(
datum=gen,
@@ -173,9 +174,9 @@ def get_dataset(self) -> xr.Dataset:
return ret
def _get_caption(self, key: str, varname: str, vmin: float, vmax: float) -> str:
- if varname in self._metadata:
- caption_name = self._metadata[varname].long_name
- units = self._metadata[varname].units
+ if varname in self._variable_metadata:
+ caption_name = self._variable_metadata[varname].long_name
+ units = self._variable_metadata[varname].units
else:
caption_name, units = varname, "unknown_units"
caption = self._captions[key].format(name=caption_name, units=units)
diff --git a/fme/fme/core/aggregator/null.py b/fme/fme/ace/aggregator/null.py
similarity index 100%
rename from fme/fme/core/aggregator/null.py
rename to fme/fme/ace/aggregator/null.py
diff --git a/fme/fme/core/aggregator/one_step/__init__.py b/fme/fme/ace/aggregator/one_step/__init__.py
similarity index 100%
rename from fme/fme/core/aggregator/one_step/__init__.py
rename to fme/fme/ace/aggregator/one_step/__init__.py
diff --git a/fme/fme/core/aggregator/one_step/main.py b/fme/fme/ace/aggregator/one_step/main.py
similarity index 66%
rename from fme/fme/core/aggregator/one_step/main.py
rename to fme/fme/ace/aggregator/one_step/main.py
index 23d6d0a..18a5121 100644
--- a/fme/fme/core/aggregator/one_step/main.py
+++ b/fme/fme/ace/aggregator/one_step/main.py
@@ -1,9 +1,12 @@
-from typing import Mapping, Optional, Protocol
+from typing import Dict, Mapping, Optional, Protocol
+import numpy as np
import torch
-from fme.core.aggregator.one_step.derived import DerivedMetricsAggregator
-from fme.core.data_loading.data_typing import SigmaCoordinates, VariableMetadata
+from fme.ace.stepper import TrainOutput
+from fme.core.dataset.data_typing import VariableMetadata
+from fme.core.generics.aggregator import AggregatorABC
+from fme.core.gridded_ops import GriddedOperations
from fme.core.typing_ import TensorMapping
from .map import MapAggregator
@@ -12,8 +15,7 @@
class _Aggregator(Protocol):
- def get_logs(self, label: str) -> TensorMapping:
- ...
+ def get_logs(self, label: str) -> TensorMapping: ...
def record_batch(
self,
@@ -22,11 +24,10 @@ def record_batch(
gen_data: TensorMapping,
target_data_norm: TensorMapping,
gen_data_norm: TensorMapping,
- ) -> None:
- ...
+ ) -> None: ...
-class OneStepAggregator:
+class OneStepAggregator(AggregatorABC[TrainOutput]):
"""
Aggregates statistics for the first timestep.
@@ -36,46 +37,42 @@ class OneStepAggregator:
def __init__(
self,
- area_weights: torch.Tensor,
- sigma_coordinates: SigmaCoordinates,
- metadata: Optional[Mapping[str, VariableMetadata]] = None,
+ gridded_operations: GriddedOperations,
+ variable_metadata: Optional[Mapping[str, VariableMetadata]] = None,
loss_scaling: Optional[TensorMapping] = None,
):
"""
Args:
- area_weights: Weights for each horizontal grid coordinate
- sigma_coordinates: Coordinates for defining pressure levels.
- metadata: Metadata for each variable.
+ gridded_operations: Operations for computing metrics on gridded data.
+ variable_metadata: Metadata for each variable.
loss_scaling: Dictionary of variables and their scaling factors
used in loss computation.
"""
- self._aggregators: Mapping[str, _Aggregator] = {
- "snapshot": SnapshotAggregator(metadata),
- "mean": MeanAggregator(area_weights),
- "derived": DerivedMetricsAggregator(area_weights, sigma_coordinates),
- "mean_map": MapAggregator(metadata),
+ aggregators: Dict[str, _Aggregator] = {
+ "mean": MeanAggregator(gridded_operations)
}
+ aggregators["snapshot"] = SnapshotAggregator(variable_metadata)
+ aggregators["mean_map"] = MapAggregator(variable_metadata)
+ self._aggregators = aggregators
self._loss_scaling = loss_scaling or {}
@torch.no_grad()
def record_batch(
self,
- loss: float,
- target_data: TensorMapping,
- gen_data: TensorMapping,
- target_data_norm: TensorMapping,
- gen_data_norm: TensorMapping,
+ batch: TrainOutput,
):
- if len(target_data) == 0:
+ if len(batch.target_data) == 0:
raise ValueError("No data in target_data")
- if len(gen_data) == 0:
+ if len(batch.gen_data) == 0:
raise ValueError("No data in gen_data")
+ gen_data_norm = batch.normalize(batch.gen_data)
+ target_data_norm = batch.normalize(batch.target_data)
for agg in self._aggregators.values():
agg.record_batch(
- loss=loss,
- target_data=target_data,
- gen_data=gen_data,
+ loss=batch.metrics.get("loss", np.nan),
+ target_data=batch.target_data,
+ gen_data=batch.gen_data,
target_data_norm=target_data_norm,
gen_data_norm=gen_data_norm,
)
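A condensed sketch of the new TrainOutput-based interface (the full version appears in the new test_main.py further below); the identity normalize function and the tensor shapes here are illustrative.

import numpy as np
import torch
import xarray as xr

from fme.ace.aggregator.one_step import OneStepAggregator
from fme.ace.stepper import TrainOutput
from fme.core.device import get_device
from fme.core.gridded_ops import LatLonOperations

agg = OneStepAggregator(LatLonOperations(torch.ones(2).to(get_device())))
data = {"a": torch.randn(1, 2, 2, 2, device=get_device())}
agg.record_batch(
    batch=TrainOutput(
        metrics={"loss": 1.0},
        target_data=data,
        gen_data=data,
        time=xr.DataArray(np.zeros((1, 2)), dims=["sample", "time"]),
        normalize=lambda x: x,  # identity normalization, for illustration only
    )
)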
diff --git a/fme/fme/core/aggregator/one_step/map.py b/fme/fme/ace/aggregator/one_step/map.py
similarity index 98%
rename from fme/fme/core/aggregator/one_step/map.py
rename to fme/fme/ace/aggregator/one_step/map.py
index 2adb564..81b4812 100644
--- a/fme/fme/core/aggregator/one_step/map.py
+++ b/fme/fme/ace/aggregator/one_step/map.py
@@ -2,7 +2,7 @@
import torch
-from fme.core.data_loading.data_typing import VariableMetadata
+from fme.core.dataset.data_typing import VariableMetadata
from fme.core.distributed import Distributed
from fme.core.typing_ import TensorDict, TensorMapping
from fme.core.wandb import Image
diff --git a/fme/fme/core/aggregator/one_step/reduced.py b/fme/fme/ace/aggregator/one_step/reduced.py
similarity index 76%
rename from fme/fme/core/aggregator/one_step/reduced.py
rename to fme/fme/ace/aggregator/one_step/reduced.py
index beb9932..6ce35a2 100644
--- a/fme/fme/core/aggregator/one_step/reduced.py
+++ b/fme/fme/ace/aggregator/one_step/reduced.py
@@ -1,11 +1,12 @@
-from typing import Dict, Optional, Union
+from typing import Dict, Optional
+import numpy as np
import torch
import xarray as xr
-from fme.core import metrics
from fme.core.device import get_device
from fme.core.distributed import Distributed
+from fme.core.gridded_ops import GriddedOperations
from fme.core.typing_ import TensorMapping
from .reduced_metrics import AreaWeightedReducedMetric, ReducedMetric
@@ -28,12 +29,10 @@ class MeanAggregator:
def __init__(
self,
- area_weights: torch.Tensor,
+ gridded_operations: GriddedOperations,
target_time: int = 1,
):
- self._area_weights = area_weights
- self._shape_x = None
- self._shape_y = None
+ self._gridded_operations = gridded_operations
self._n_batches = 0
self._loss = torch.tensor(0.0, device=get_device())
self._variable_metrics: Optional[Dict[str, Dict[str, ReducedMetric]]] = None
@@ -49,37 +48,35 @@ def _get_variable_metrics(self, gen_data: TensorMapping):
}
device = get_device()
for key in gen_data:
- self._variable_metrics["weighted_rmse"][
- key
- ] = AreaWeightedReducedMetric(
- area_weights=self._area_weights,
- device=device,
- compute_metric=metrics.root_mean_squared_error,
+ self._variable_metrics["weighted_rmse"][key] = (
+ AreaWeightedReducedMetric(
+ device=device,
+ compute_metric=self._gridded_operations.area_weighted_rmse,
+ )
)
- self._variable_metrics["weighted_bias"][
- key
- ] = AreaWeightedReducedMetric(
- area_weights=self._area_weights,
- device=device,
- compute_metric=metrics.weighted_mean_bias,
+ self._variable_metrics["weighted_bias"][key] = (
+ AreaWeightedReducedMetric(
+ device=device,
+ compute_metric=self._gridded_operations.area_weighted_mean_bias,
+ )
)
- self._variable_metrics["weighted_grad_mag_percent_diff"][
- key
- ] = AreaWeightedReducedMetric(
- area_weights=self._area_weights,
- device=device,
- compute_metric=metrics.gradient_magnitude_percent_diff,
+ self._variable_metrics["weighted_grad_mag_percent_diff"][key] = (
+ AreaWeightedReducedMetric(
+ device=device,
+ compute_metric=self._gridded_operations.area_weighted_gradient_magnitude_percent_diff, # noqa: E501
+ )
)
+
return self._variable_metrics
@torch.no_grad()
def record_batch(
self,
- loss: float,
target_data: TensorMapping,
gen_data: TensorMapping,
target_data_norm: TensorMapping,
gen_data_norm: TensorMapping,
+ loss: torch.Tensor = torch.tensor(np.nan),
i_time_start: int = 0,
):
self._loss += loss
@@ -102,9 +99,7 @@ def record_batch(
def _get_data(self):
if self._variable_metrics is None or self._n_batches == 0:
raise ValueError("No batches have been recorded.")
- data: Dict[str, Union[float, torch.Tensor]] = {
- "loss": self._loss / self._n_batches
- }
+ data: Dict[str, torch.Tensor] = {"loss": self._loss / self._n_batches}
for metric in self._variable_metrics:
for key in self._variable_metrics[metric]:
data[f"{metric}/{key}"] = (
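A short sketch of the updated MeanAggregator construction, mirroring the revised tests below; loss is now an optional keyword argument that defaults to NaN.

import torch
from fme.ace.aggregator.one_step.reduced import MeanAggregator
from fme.core.device import get_device
from fme.core.gridded_ops import LatLonOperations

agg = MeanAggregator(LatLonOperations(torch.ones(4).to(get_device())))
data = {"a": torch.ones(2, 3, 4, 4, device=get_device())}
agg.record_batch(
    loss=1.0,  # optional; omitted losses are recorded as NaN
    target_data=data,
    gen_data=data,
    target_data_norm=data,
    gen_data_norm=data,
)
logs = agg.get_logs(label="metrics")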
diff --git a/fme/fme/core/aggregator/one_step/reduced_metrics.py b/fme/fme/ace/aggregator/one_step/reduced_metrics.py
similarity index 82%
rename from fme/fme/core/aggregator/one_step/reduced_metrics.py
rename to fme/fme/ace/aggregator/one_step/reduced_metrics.py
index d821307..3957b3d 100644
--- a/fme/fme/core/aggregator/one_step/reduced_metrics.py
+++ b/fme/fme/ace/aggregator/one_step/reduced_metrics.py
@@ -4,11 +4,10 @@
to turn metric functions that may have different APIs into a common API,
so that they can be iterated over and called in the same way in a loop.
"""
-from typing import Optional, Protocol
-import torch
+from typing import Protocol
-from fme.core.metrics import Dimension
+import torch
class AreaWeightedFunction(Protocol):
@@ -21,10 +20,7 @@ def __call__(
self,
truth: torch.Tensor,
predicted: torch.Tensor,
- weights: Optional[torch.Tensor] = None,
- dim: Dimension = (),
- ) -> torch.Tensor:
- ...
+ ) -> torch.Tensor: ...
class ReducedMetric(Protocol):
@@ -52,11 +48,9 @@ class AreaWeightedReducedMetric:
def __init__(
self,
- area_weights: torch.Tensor,
device: torch.device,
compute_metric: AreaWeightedFunction,
):
- self._area_weights = area_weights
self._compute_metric = compute_metric
self._total = None
self._device = device
@@ -68,9 +62,7 @@ def record(self, target: torch.Tensor, gen: torch.Tensor):
target: Target data. Should have shape [batch, time, height, width].
gen: Generated data. Should have shape [batch, time, height, width].
"""
- new_value = self._compute_metric(
- target, gen, weights=self._area_weights, dim=(-2, -1)
- ).mean(dim=0)
+ new_value = self._compute_metric(target, gen).mean(dim=0)
if self._total is None:
self._total = torch.zeros_like(new_value, device=self._device)
self._total += new_value
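To illustrate the slimmer AreaWeightedFunction protocol, a hedged sketch of wiring a GriddedOperations method into AreaWeightedReducedMetric, as reduced.py above now does (tensor shapes are illustrative).

import torch
from fme.ace.aggregator.one_step.reduced_metrics import AreaWeightedReducedMetric
from fme.core.device import get_device
from fme.core.gridded_ops import LatLonOperations

ops = LatLonOperations(torch.ones(4).to(get_device()))
metric = AreaWeightedReducedMetric(
    device=get_device(),
    compute_metric=ops.area_weighted_rmse,  # area weights are bound inside ops
)
target = torch.ones(2, 3, 4, 4, device=get_device())
gen = torch.zeros(2, 3, 4, 4, device=get_device())
metric.record(target=target, gen=gen)  # accumulates the batch-mean metric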
diff --git a/fme/fme/core/aggregator/one_step/snapshot.py b/fme/fme/ace/aggregator/one_step/snapshot.py
similarity index 67%
rename from fme/fme/core/aggregator/one_step/snapshot.py
rename to fme/fme/ace/aggregator/one_step/snapshot.py
index 90f0bae..d066c44 100644
--- a/fme/fme/core/aggregator/one_step/snapshot.py
+++ b/fme/fme/ace/aggregator/one_step/snapshot.py
@@ -1,13 +1,12 @@
from typing import Dict, Mapping, Optional
-import matplotlib.pyplot as plt
import torch
-from fme.core.data_loading.data_typing import VariableMetadata
+from fme.core.dataset.data_typing import VariableMetadata
from fme.core.typing_ import TensorMapping
-from fme.core.wandb import Image, WandB
+from fme.core.wandb import Image
-from ..plotting import get_cmap_limits, plot_imshow
+from ..plotting import plot_paneled_data
class SnapshotAggregator:
@@ -18,11 +17,11 @@ class SnapshotAggregator:
_captions = {
"full-field": (
"{name} one step full field for last sample; "
- "(left) generated and (right) target [{units}]"
+ "(top) generated and (bottom) target [{units}]"
),
"residual": (
"{name} one step residual (prediction - previous time) for last sample; "
- "(left) generated and (right) target [{units}]"
+ "(top) generated and (bottom) target [{units}]"
),
"error": (
"{name} one step full field error (generated - target) "
@@ -68,7 +67,6 @@ def get_logs(self, label: str) -> Dict[str, Image]:
input_time = 0
target_time = 1
image_logs = {}
- wandb = WandB.get_instance()
for name in self._gen_data.keys():
# use first sample in batch
gen = self._gen_data[name].select(dim=time_dim, index=target_time)[0].cpu()
@@ -78,42 +76,26 @@ def get_logs(self, label: str) -> Dict[str, Image]:
input = (
self._target_data[name].select(dim=time_dim, index=input_time)[0].cpu()
)
- gap_shape = (input.shape[-2], 4)
- gap = torch.full(gap_shape, target.min())
- gap_res = torch.full(gap_shape, (target - input).min())
images = {}
- images["error"] = (gen - target).numpy()
- images["full-field"] = torch.cat((gen, gap, target), axis=1).numpy()
- images["residual"] = torch.cat(
- (
- gen - input,
- gap_res,
- target - input,
- ),
- axis=1,
- ).numpy()
+ images["error"] = [[(gen - target).numpy()]]
+ images["full-field"] = [[gen.numpy()], [target.numpy()]]
+ images["residual"] = [[(gen - input).numpy()], [(target - input).numpy()]]
for key, data in images.items():
if key == "error" or key == "residual":
diverging = True
- cmap = "RdBu_r"
else:
diverging = False
- cmap = None
- vmin, vmax = get_cmap_limits(data, diverging=diverging)
- caption = self._get_caption(key, name, vmin, vmax)
- fig = plot_imshow(data, vmin=vmin, vmax=vmax, cmap=cmap)
- wandb_image = wandb.Image(fig, caption=caption)
- plt.close(fig)
+ caption = self._get_caption(key, name)
+ wandb_image = plot_paneled_data(data, diverging, caption=caption)
image_logs[f"image-{key}/{name}"] = wandb_image
image_logs = {f"{label}/{key}": image_logs[key] for key in image_logs}
return image_logs
- def _get_caption(self, key: str, name: str, vmin: float, vmax: float) -> str:
+ def _get_caption(self, key: str, name: str) -> str:
if name in self._metadata:
caption_name = self._metadata[name].long_name
units = self._metadata[name].units
else:
caption_name, units = name, "unknown_units"
caption = self._captions[key].format(name=caption_name, units=units)
- caption += f" vmin={vmin:.4g}, vmax={vmax:.4g}."
return caption
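The snapshot images are now assembled as nested lists of panels (one inner list per row of the stitched figure) and passed to plot_paneled_data, which returns a wandb Image directly. A schematic sketch with placeholder arrays:

import numpy as np

from fme.ace.aggregator.plotting import plot_paneled_data

gen = np.random.randn(4, 8)
target = np.random.randn(4, 8)
full_field = [[gen], [target]]   # one panel per row, as in SnapshotAggregator
error = [[gen - target]]         # single panel
image = plot_paneled_data(full_field, diverging=False, caption="full field")
error_image = plot_paneled_data(error, diverging=True, caption="error")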
diff --git a/fme/fme/ace/aggregator/one_step/test_main.py b/fme/fme/ace/aggregator/one_step/test_main.py
new file mode 100644
index 0000000..84b4ddc
--- /dev/null
+++ b/fme/fme/ace/aggregator/one_step/test_main.py
@@ -0,0 +1,88 @@
+import numpy as np
+import pytest
+import torch
+import xarray as xr
+
+from fme.ace.aggregator.one_step import OneStepAggregator
+from fme.ace.stepper import TrainOutput
+from fme.core.device import get_device
+from fme.core.gridded_ops import LatLonOperations
+
+
+def test_labels_exist():
+ n_sample = 10
+ n_time = 3
+ nx, ny = 2, 2
+ loss = 1.0
+ area_weights = torch.ones(ny).to(get_device())
+ agg = OneStepAggregator(LatLonOperations(area_weights))
+ target_data = {"a": torch.randn(n_sample, n_time, nx, ny, device=get_device())}
+ gen_data = {"a": torch.randn(n_sample, n_time, nx, ny, device=get_device())}
+ agg.record_batch(
+ batch=TrainOutput(
+ metrics={"loss": loss},
+ target_data=target_data,
+ gen_data=gen_data,
+ time=xr.DataArray(np.zeros((n_sample, n_time)), dims=["sample", "time"]),
+ normalize=lambda x: x,
+ ),
+ )
+ logs = agg.get_logs(label="test")
+ assert "test/mean/loss" in logs
+ assert "test/mean/weighted_rmse/a" in logs
+ assert "test/mean/weighted_bias/a" in logs
+ assert "test/mean/weighted_grad_mag_percent_diff/a" in logs
+ assert "test/snapshot/image-full-field/a" in logs
+ assert "test/snapshot/image-residual/a" in logs
+ assert "test/snapshot/image-error/a" in logs
+
+
+def test_aggregator_raises_on_no_data():
+ """
+ Test that the aggregator raises a ValueError when given empty target and generated data.
+ """
+ ny = 2
+ area_weights = torch.ones(ny).to(get_device())
+ agg = OneStepAggregator(LatLonOperations(area_weights))
+ with pytest.raises(ValueError) as excinfo:
+ agg.record_batch(
+ batch=TrainOutput(
+ metrics={"loss": 1.0},
+ target_data={},
+ gen_data={},
+ time=xr.DataArray(np.zeros((0, 0)), dims=["sample", "time"]),
+ normalize=lambda x: x,
+ ),
+ )
+ # check that the raised exception contains the right substring
+ assert "No data" in str(excinfo.value)
+
+
+def test__get_loss_scaled_mse_components():
+ loss_scaling = {
+ "a": torch.tensor(1.0),
+ "b": torch.tensor(0.5),
+ }
+ agg = OneStepAggregator(
+ gridded_operations=LatLonOperations(
+ area_weights=torch.ones(10).to(get_device())
+ ),
+ loss_scaling=loss_scaling,
+ )
+
+ logs = {
+ "test/mean/weighted_rmse/a": 1.0,
+ "test/mean/weighted_rmse/b": 4.0,
+ "test/mean/weighted_rmse/c": 0.0,
+ }
+ result = agg._get_loss_scaled_mse_components(logs, "test")
+ scaled_squared_errors_sum = (1.0 / 1.0) ** 2 + (4.0 / 0.5) ** 2
+ assert (
+ result["test/mean/mse_fractional_components/a"] == 1 / scaled_squared_errors_sum
+ )
+ assert (
+ result["test/mean/mse_fractional_components/b"]
+ == 64 / scaled_squared_errors_sum
+ )
+ assert "test/mean/mse_fractional_components/c" not in result
diff --git a/fme/fme/core/aggregator/one_step/test_reduced.py b/fme/fme/ace/aggregator/one_step/test_reduced.py
similarity index 62%
rename from fme/fme/core/aggregator/one_step/test_reduced.py
rename to fme/fme/ace/aggregator/one_step/test_reduced.py
index cbdb10e..25b8960 100644
--- a/fme/fme/core/aggregator/one_step/test_reduced.py
+++ b/fme/fme/ace/aggregator/one_step/test_reduced.py
@@ -2,8 +2,9 @@
import pytest
import torch
-from fme.core.aggregator.one_step.reduced import MeanAggregator
+from fme.ace.aggregator.one_step.reduced import MeanAggregator
from fme.core.device import get_device
+from fme.core.gridded_ops import LatLonOperations
from fme.core.testing import mock_distributed
@@ -15,9 +16,15 @@ def test_mean_metrics_call_distributed():
"""
with mock_distributed(-1.0) as mock:
area_weights = torch.ones([4]).to(get_device())
- agg = MeanAggregator(area_weights)
+ agg = MeanAggregator(LatLonOperations(area_weights))
sample_data = {"a": torch.ones([2, 3, 4, 4], device=get_device())}
- agg.record_batch(1.0, sample_data, sample_data, sample_data, sample_data)
+ agg.record_batch(
+ loss=1.0,
+ target_data=sample_data,
+ gen_data=sample_data,
+ target_data_norm=sample_data,
+ gen_data_norm=sample_data,
+ )
logs = agg.get_logs(label="metrics")
assert logs["metrics/loss"] == -1.0
assert logs["metrics/weighted_rmse/a"] == -1.0
@@ -30,14 +37,14 @@ def test_i_time_start_gets_correct_time_one_step_windows():
# the data from the correct timestep is piped into the aggregator.
target_time = 3
area_weights = torch.ones([4]).to(get_device())
- agg = MeanAggregator(area_weights, target_time=target_time)
+ agg = MeanAggregator(LatLonOperations(area_weights), target_time=target_time)
target_data = {"a": torch.zeros([2, 1, 4, 4], device=get_device())}
for i in range(5):
sample_data = {
"a": torch.full([2, 1, 4, 4], fill_value=float(i), device=get_device())
}
agg.record_batch(
- 1.0,
+ loss=1.0,
target_data=target_data,
gen_data=sample_data,
target_data_norm=target_data,
@@ -62,7 +69,7 @@ def test_i_time_start_gets_correct_time_longer_windows(
# while this directly tests the "mean" result, this is really a test that
# the data from the correct timestep is piped into the aggregator.
area_weights = torch.ones([4]).to(get_device())
- agg = MeanAggregator(area_weights, target_time=target_time)
+ agg = MeanAggregator(LatLonOperations(area_weights), target_time=target_time)
target_data = {"a": torch.zeros([2, window_len, 4, 4], device=get_device())}
i_start = 0
for i in range(n_windows):
@@ -70,7 +77,7 @@ def test_i_time_start_gets_correct_time_longer_windows(
for i in range(window_len):
sample_data["a"][..., i, :, :] = float(i_start + i)
agg.record_batch(
- 1.0,
+ loss=1.0,
target_data=target_data,
gen_data=sample_data,
target_data_norm=target_data,
@@ -82,3 +89,41 @@ def test_i_time_start_gets_correct_time_longer_windows(
np.testing.assert_allclose(
float(logs["metrics/weighted_bias/a"]), float(target_time), rtol=1e-5
)
+
+
+def test_loss():
+ """
+ Basic test that the aggregator combines loss correctly
+ with multiple batches and no distributed training.
+ """
+ torch.manual_seed(0)
+ example_data = {
+ "a": torch.randn(1, 2, 5, 5, device=get_device()),
+ }
+ area_weights = torch.ones(1).to(get_device())
+ aggregator = MeanAggregator(LatLonOperations(area_weights))
+ aggregator.record_batch(
+ loss=1.0,
+ target_data=example_data,
+ gen_data=example_data,
+ target_data_norm=example_data,
+ gen_data_norm=example_data,
+ )
+ aggregator.record_batch(
+ loss=2.0,
+ target_data=example_data,
+ gen_data=example_data,
+ target_data_norm=example_data,
+ gen_data_norm=example_data,
+ )
+ logs = aggregator.get_logs(label="metrics")
+ assert logs["metrics/loss"] == 1.5
+ aggregator.record_batch(
+ loss=3.0,
+ target_data=example_data,
+ gen_data=example_data,
+ target_data_norm=example_data,
+ gen_data_norm=example_data,
+ )
+ logs = aggregator.get_logs(label="metrics")
+ assert logs["metrics/loss"] == 2.0
diff --git a/fme/fme/core/aggregator/plotting.py b/fme/fme/ace/aggregator/plotting.py
similarity index 60%
rename from fme/fme/core/aggregator/plotting.py
rename to fme/fme/ace/aggregator/plotting.py
index a4cfc8d..de44c8b 100644
--- a/fme/fme/core/aggregator/plotting.py
+++ b/fme/fme/ace/aggregator/plotting.py
@@ -6,7 +6,7 @@
from matplotlib.colors import Colormap
from matplotlib.figure import Figure
-from fme.core.wandb import WandB
+from fme.core.wandb import Image, WandB
def get_cmap_limits(data: np.ndarray, diverging=False) -> Tuple[float, float]:
@@ -27,6 +27,10 @@ def plot_imshow(
use_colorbar: bool = True,
) -> Figure:
"""Plot a 2D array using imshow, ensuring figure size is same as array size."""
+ min_ = np.min(data) if vmin is None else vmin
+ max_ = np.max(data) if vmax is None else vmax
+ if len(data.shape) == 3:
+ data = fold_healpix_data(data, fill_value=0.5 * (min_ + max_))
if flip_lat:
lat_dim = -2
data = np.flip(data, axis=lat_dim)
@@ -34,8 +38,6 @@ def plot_imshow(
if use_colorbar:
height, width = data.shape
colorbar_width = max(1, int(0.025 * width))
- min_ = np.min(data) if vmin is None else vmin
- max_ = np.max(data) if vmax is None else vmax
range_ = np.linspace(min_, max_, height)
range_ = np.repeat(range_[:, np.newaxis], repeats=colorbar_width, axis=1)
range_ = np.flipud(range_) # wandb images start from top (and left)
@@ -51,11 +53,55 @@ def plot_imshow(
return fig
+def fold_healpix_data(data: np.ndarray, fill_value: float) -> np.ndarray:
+ if data.shape[0] != 12:
+ raise ValueError(
+ "first dimension must be 12 (face) for healpix data, "
+ f"got shape {data.shape}"
+ )
+ # we want to panel the data like this, numbered by first dimension index
+ # -----------------
+ # | | | | |
+ # | | | |3 |
+ # -----------------
+ # | | | | |
+ # | | |2 |7 |
+ # -----------------
+ # | | | | |
+ # | |1 |6 |10 |
+ # -----------------
+ # | | | | |
+ # |0 |5 |9 | |
+ # -----------------
+ # | | | | |
+ # |4 |8 | | |
+ # -----------------
+ # | | | | |
+ # |11 | | | |
+ # -----------------
+ blank_panel = np.full_like(data[0], fill_value)
+ panels = [
+ [blank_panel, blank_panel, blank_panel, data[3]],
+ [blank_panel, blank_panel, data[2], data[7]],
+ [blank_panel, data[1], data[6], data[10]],
+ [data[0], data[5], data[9], blank_panel],
+ [data[4], data[8], blank_panel, blank_panel],
+ [data[11], blank_panel, blank_panel, blank_panel],
+ ]
+ return np.concatenate([np.concatenate(row, axis=1) for row in panels], axis=0)
+
+
+def fold_if_healpix_data(data: np.ndarray, fill_value: float) -> np.ndarray:
+ if data.shape[0] == 12:
+ return fold_healpix_data(data, fill_value)
+ return data
+
+
def plot_paneled_data(
data: List[List[np.ndarray]],
diverging: bool,
caption: Optional[str] = None,
-) -> Figure:
+) -> Image:
"""Plot a list of 2D data arrays in a paneled plot."""
if diverging:
cmap = "RdBu_r"
@@ -77,7 +123,11 @@ def plot_paneled_data(
caption += f"vmin={vmin:.4g}, vmax={vmax:.4g}."
- all_data = _stitch_data_panels(data, vmin=vmin)
+ if diverging:
+ fill_value = 0.5 * (vmin + vmax)
+ else:
+ fill_value = vmin
+ all_data = _stitch_data_panels(data, fill_value=fill_value)
fig = plot_imshow(all_data, vmin=vmin, vmax=vmax, cmap=cmap)
wandb = WandB.get_instance()
@@ -91,10 +141,14 @@ def plot_paneled_data(
return wandb_image
-def _stitch_data_panels(data: List[List[np.ndarray]], vmin) -> np.ndarray:
+def _stitch_data_panels(data: List[List[np.ndarray]], fill_value) -> np.ndarray:
for row in data:
if len(row) != len(data[0]):
raise ValueError("All rows must have the same number of panels.")
+ data = [
+ [fold_if_healpix_data(arr, fill_value=fill_value) for arr in row]
+ for row in data
+ ]
n_rows = len(data)
n_cols = len(data[0])
for row in data:
@@ -102,14 +156,12 @@ def _stitch_data_panels(data: List[List[np.ndarray]], vmin) -> np.ndarray:
if arr.shape != data[0][0].shape:
raise ValueError("All panels must have the same shape.")
- stitched_data = (
- np.zeros(
- (
- n_rows * data[0][0].shape[0] + n_rows - 1,
- n_cols * data[0][0].shape[1] + n_cols - 1,
- )
- )
- + vmin
+ stitched_data = np.full(
+ (
+ n_rows * data[0][0].shape[0] + n_rows - 1,
+ n_cols * data[0][0].shape[1] + n_cols - 1,
+ ),
+ fill_value=fill_value,
)
# iterate over rows backwards, as the image starts in the bottom left
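A quick sketch of the new HEALPix folding helper: twelve (H, W) faces are arranged into the staircase layout drawn in the comment above, yielding a (6*H, 4*W) array whose blank panels are filled with fill_value (this mirrors test_fold_healpix_data below).

import numpy as np

from fme.ace.aggregator.plotting import fold_healpix_data

faces = np.random.randn(12, 2, 3)  # (face, height, width)
folded = fold_healpix_data(faces, fill_value=0.0)
assert folded.shape == (6 * 2, 4 * 3)  # six rows and four columns of faces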
diff --git a/fme/fme/ace/aggregator/test_plotting.py b/fme/fme/ace/aggregator/test_plotting.py
new file mode 100644
index 0000000..0143fb6
--- /dev/null
+++ b/fme/fme/ace/aggregator/test_plotting.py
@@ -0,0 +1,110 @@
+import numpy as np
+import pytest
+
+from .plotting import (
+ _stitch_data_panels,
+ fold_healpix_data,
+ get_cmap_limits,
+ plot_imshow,
+ plot_paneled_data,
+)
+
+
+def test_cmap_limits():
+ data = np.array([1, 2, 3])
+ vmin, vmax = get_cmap_limits(data)
+ assert vmin == 1
+ assert vmax == 3
+
+
+def test_cmap_limits_diverging():
+ data = np.array([-1, 2, 3])
+ vmin, vmax = get_cmap_limits(data, diverging=True)
+ assert vmin == -3
+ assert vmax == 3
+
+
+@pytest.mark.parametrize("use_colorbar", [True, False])
+def test_plot_imshow(use_colorbar):
+ shape = [10, 15]
+ data = np.random.randn(*shape)
+ fig = plot_imshow(np.array(data), use_colorbar=use_colorbar)
+ width, height = (fig.get_size_inches() * fig.dpi).astype(int)
+ if use_colorbar:
+ # colorbar is no more than 15% of the width but greater than 0 pixels
+ assert shape[1] < width <= int(shape[1] * 1.15)
+ assert height == shape[0]
+ else:
+ assert [height, width] == shape
+
+
+def test_fold_healpix_data():
+ face_shape = [2, 3]
+ data = np.random.randn(12, *face_shape)
+ folded = fold_healpix_data(data, fill_value=0)
+ expected_shape = (6 * face_shape[0], 4 * face_shape[1])
+ assert folded.shape == expected_shape
+
+
+@pytest.mark.parametrize("use_colorbar", [True, False])
+def test_plot_imshow_healpix(use_colorbar):
+ face_shape = [4, 6]
+ shape = [6 * face_shape[0], 4 * face_shape[1]]
+ data = np.random.randn(12, *face_shape)
+ fig = plot_imshow(np.array(data), use_colorbar=use_colorbar)
+ width, height = (fig.get_size_inches() * fig.dpi).astype(int)
+ if use_colorbar:
+ # colorbar is no more than 15% of the width but greater than 0 pixels
+ assert shape[1] < width <= int(shape[1] * 1.15)
+ assert height == shape[0]
+ else:
+ assert [height, width] == shape
+
+
+def test_stitch_data_panels():
+ data = [
+ [np.array([[1, 2]]), np.array([[3, 4]])],
+ [np.array([[5, 6]]), np.array([[7, 8]])],
+ ]
+ stitched = _stitch_data_panels(data, fill_value=1)
+ expected = np.array(
+ [ # vertical orientation is swapped as data starts from bottom-left
+ [5, 6, 1, 7, 8],
+ [1, 1, 1, 1, 1],
+ [1, 2, 1, 3, 4],
+ ]
+ )
+ assert np.array_equal(stitched, expected)
+
+
+@pytest.mark.parametrize(
+ "shape, img_shape",
+ [
+ pytest.param(
+ [12, 2, 3],
+ [
+ 27, # 3 * 4 + 1 (divider) + 3 * 4 + 2 (colorbar)
+ 25, # 2 * 6 + 1 (divider) + 2 * 6
+ ],
+ id="healpix",
+ ),
+ pytest.param(
+ [2, 3],
+ [
+ 9, # 3 + 1 (divider) + 3 + 2 (colorbar)
+ 5, # 2 + 1 (divider) + 2
+ ],
+ id="latlon",
+ ),
+ ],
+)
+def test_plot_paneled_data(shape, img_shape):
+ panel = np.random.uniform(size=shape)
+ data = [
+ [panel, panel],
+ [panel, panel],
+ ]
+ fig = plot_paneled_data(data, diverging=False)
+ assert np.array_equal(fig.image.size, img_shape)
+ fig = plot_paneled_data(data, diverging=True)
+ assert np.array_equal(fig.image.size, img_shape)
diff --git a/fme/fme/core/aggregator/train.py b/fme/fme/ace/aggregator/train.py
similarity index 73%
rename from fme/fme/core/aggregator/train.py
rename to fme/fme/ace/aggregator/train.py
index 6959615..bcecb69 100644
--- a/fme/fme/core/aggregator/train.py
+++ b/fme/fme/ace/aggregator/train.py
@@ -1,10 +1,14 @@
+from typing import Dict
+
import torch
+from fme.ace.stepper import TrainOutput
from fme.core.device import get_device
from fme.core.distributed import Distributed
+from fme.core.generics.aggregator import AggregatorABC
-class TrainAggregator:
+class TrainAggregator(AggregatorABC[TrainOutput]):
"""
Aggregates statistics for the first timestep.
@@ -17,12 +21,12 @@ def __init__(self):
self._loss = torch.tensor(0.0, device=get_device())
@torch.no_grad()
- def record_batch(self, loss: float):
- self._loss += loss
+ def record_batch(self, batch: TrainOutput):
+ self._loss += batch.metrics["loss"]
self._n_batches += 1
@torch.no_grad()
- def get_logs(self, label: str):
+ def get_logs(self, label: str) -> Dict[str, torch.Tensor]:
"""
Returns logs as can be reported to WandB.
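Minimal sketch of the TrainOutput-based TrainAggregator interface; the TrainOutput fields follow the construction used in the new one-step tests, with illustrative values.

import numpy as np
import torch
import xarray as xr

from fme.ace.aggregator.train import TrainAggregator
from fme.ace.stepper import TrainOutput

agg = TrainAggregator()
data = {"a": torch.zeros(1, 1, 2, 2)}
agg.record_batch(
    TrainOutput(
        metrics={"loss": 0.25},
        target_data=data,
        gen_data=data,
        time=xr.DataArray(np.zeros((1, 1)), dims=["sample", "time"]),
        normalize=lambda x: x,
    )
)
logs = agg.get_logs(label="train")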
diff --git a/fme/fme/core/data_loading/__init__.py b/fme/fme/ace/data_loading/__init__.py
similarity index 100%
rename from fme/fme/core/data_loading/__init__.py
rename to fme/fme/ace/data_loading/__init__.py
diff --git a/fme/fme/ace/data_loading/batch_data.py b/fme/fme/ace/data_loading/batch_data.py
new file mode 100644
index 0000000..bf177ce
--- /dev/null
+++ b/fme/fme/ace/data_loading/batch_data.py
@@ -0,0 +1,571 @@
+import dataclasses
+import datetime
+import logging
+from typing import (
+ Any,
+ Callable,
+ Collection,
+ Dict,
+ Generic,
+ Iterable,
+ Iterator,
+ List,
+ Literal,
+ Mapping,
+ Optional,
+ Sequence,
+ Sized,
+ Tuple,
+ TypeVar,
+ Union,
+)
+
+import numpy as np
+import torch
+import xarray as xr
+from torch.utils.data import default_collate
+
+from fme.ace.requirements import PrognosticStateDataRequirements
+from fme.core.coordinates import HorizontalCoordinates, HybridSigmaPressureCoordinate
+from fme.core.dataset.data_typing import VariableMetadata
+from fme.core.dataset.xarray import DatasetProperties
+from fme.core.device import get_device
+from fme.core.generics.data import DataLoader, GriddedDataABC, InferenceDataABC
+from fme.core.gridded_ops import GriddedOperations
+from fme.core.typing_ import TensorDict, TensorMapping
+
+SelfType = TypeVar("SelfType", bound="BatchData")
+
+
+def _check_device(data: TensorMapping, device: torch.device):
+ for v in data.values():
+ if v.device != device:
+ raise ValueError(f"data must be on {device}")
+
+
+class PrognosticState:
+ """
+ Thin typing wrapper around BatchData to indicate that the data is a prognostic
+ state, such as an initial condition or final state when evolving forward in time.
+ """
+
+ def __init__(self, data: "BatchData"):
+ """
+ Initialize the state.
+
+ Args:
+ data: The data to initialize the state with.
+ """
+ self._data = data
+
+ def to_device(self) -> "PrognosticState":
+ return PrognosticState(self._data.to_device())
+
+ def as_batch_data(self) -> "BatchData":
+ return self._data
+
+
+@dataclasses.dataclass
+class BatchData:
+ """A container for the data and time coordinates of a batch.
+
+ Parameters:
+ data: Data for each variable in each sample of shape (sample, time, ...),
+ concatenated along samples to make a batch. To be used directly in training,
+ validation, and inference.
+ time: An array representing time coordinates for each sample in the batch,
+ concatenated along samples to make a batch. To be used in writing out
+ inference predictions with time coordinates, not directly in ML.
+ horizontal_dims: Horizontal dimensions of the data. Used for writing to
+ netCDF files.
+ """
+
+ data: TensorMapping
+ time: xr.DataArray
+ horizontal_dims: List[str] = dataclasses.field(
+ default_factory=lambda: ["lat", "lon"]
+ )
+
+ @property
+ def dims(self) -> List[str]:
+ return ["sample", "time"] + self.horizontal_dims
+
+ def to_device(self) -> "BatchData":
+ return self.__class__(
+ data={k: v.to(get_device()) for k, v in self.data.items()},
+ time=self.time,
+ horizontal_dims=self.horizontal_dims,
+ )
+
+ @classmethod
+ def _get_kwargs(cls, horizontal_dims: Optional[List[str]]) -> Dict[str, Any]:
+ if horizontal_dims is None:
+ kwargs = {}
+ else:
+ kwargs = {"horizontal_dims": horizontal_dims}
+ return kwargs
+
+ @classmethod
+ def new_on_cpu(
+ cls,
+ data: TensorMapping,
+ time: xr.DataArray,
+ horizontal_dims: Optional[List[str]] = None,
+ ) -> "BatchData":
+ _check_device(data, torch.device("cpu"))
+ kwargs = cls._get_kwargs(horizontal_dims)
+ return BatchData(
+ data=data,
+ time=time,
+ **kwargs,
+ )
+
+ @classmethod
+ def new_on_device(
+ cls,
+ data: TensorMapping,
+ time: xr.DataArray,
+ horizontal_dims: Optional[List[str]] = None,
+ ) -> "BatchData":
+ """
+ Create a BatchData, checking that the data is already on the current global device specified by get_device().
+ """
+ _check_device(data, get_device())
+ kwargs = cls._get_kwargs(horizontal_dims)
+ return BatchData(
+ data=data,
+ time=time,
+ **kwargs,
+ )
+
+ def __post_init__(self):
+ if len(self.time.shape) != 2:
+ raise ValueError(
+ "Expected time to have shape (n_samples, n_times), got shape "
+ f"{self.time.shape}."
+ )
+ for k, v in self.data.items():
+ if v.shape[:2] != self.time.shape[:2]:
+ raise ValueError(
+ f"Data for variable {k} has shape {v.shape}, expected shape "
+ f"(n_samples, n_times) for time but got shape "
+ f"{self.time.shape}."
+ )
+
+ @classmethod
+ def from_sample_tuples(
+ cls,
+ samples: Sequence[Tuple[TensorMapping, xr.DataArray]],
+ sample_dim_name: str = "sample",
+ horizontal_dims: Optional[List[str]] = None,
+ ) -> "BatchData":
+ sample_data, sample_times = zip(*samples)
+ batch_data = default_collate(sample_data)
+ batch_time = xr.concat(sample_times, dim=sample_dim_name)
+ return BatchData.new_on_cpu(
+ data=batch_data,
+ time=batch_time,
+ horizontal_dims=horizontal_dims,
+ )
+
+ def compute_derived_variables(
+ self: SelfType,
+ derive_func: Callable[[TensorMapping, TensorMapping], TensorDict],
+ forcing_data: SelfType,
+ ) -> SelfType:
+ """
+ Compute derived variables from the data and forcing data.
+
+ The forcing data must have the same time coordinate as the batch data.
+
+ Args:
+ derive_func: A function that takes the data and forcing data and returns a
+ dictionary of derived variables.
+ forcing_data: The forcing data to compute derived variables from.
+ """
+ if not np.all(forcing_data.time.values == self.time.values):
+ raise ValueError(
+ "Forcing data must have the same time coordinate as the batch data."
+ )
+ derived_data = derive_func(self.data, forcing_data.data)
+ return self.__class__(
+ data={**self.data, **derived_data},
+ time=self.time,
+ horizontal_dims=self.horizontal_dims,
+ )
+
+ def remove_initial_condition(self: SelfType, n_ic_timesteps: int) -> SelfType:
+ """
+ Remove the initial condition timesteps from the data.
+ """
+ if n_ic_timesteps == 0:
+ raise RuntimeError("No initial condition timesteps to remove.")
+ return self.__class__(
+ {k: v[:, n_ic_timesteps:] for k, v in self.data.items()},
+ time=self.time.isel(time=slice(n_ic_timesteps, None)),
+ horizontal_dims=self.horizontal_dims,
+ )
+
+ def subset_names(self: SelfType, names: Collection[str]) -> SelfType:
+ """
+ Subset the data to only include the given names.
+ """
+ return self.__class__(
+ {k: v for k, v in self.data.items() if k in names},
+ time=self.time,
+ horizontal_dims=self.horizontal_dims,
+ )
+
+ def get_start(
+ self: SelfType, prognostic_names: Collection[str], n_ic_timesteps: int
+ ) -> PrognosticState:
+ """
+ Get the initial condition state.
+ """
+ return PrognosticState(
+ self.subset_names(prognostic_names).select_time_slice(
+ slice(0, n_ic_timesteps)
+ )
+ )
+
+ def get_end(
+ self: SelfType, prognostic_names: Collection[str], n_ic_timesteps: int
+ ) -> PrognosticState:
+ """
+ Get the final state which can be used as a new initial condition.
+ """
+ return PrognosticState(
+ self.subset_names(prognostic_names).select_time_slice(
+ slice(-n_ic_timesteps, None)
+ )
+ )
+
+ def select_time_slice(self: SelfType, time_slice: slice) -> SelfType:
+ """
+ Select a window of data from the batch.
+ """
+ return self.__class__(
+ {k: v[:, time_slice] for k, v in self.data.items()},
+ time=self.time[:, time_slice],
+ horizontal_dims=self.horizontal_dims,
+ )
+
+ def prepend(self: SelfType, initial_condition: PrognosticState) -> SelfType:
+ """
+ Prepend the initial condition to the data.
+ """
+ initial_batch_data = initial_condition.as_batch_data()
+ filled_data = {**initial_batch_data.data}
+ example_tensor = list(initial_batch_data.data.values())[0]
+ state_data_device = list(self.data.values())[0].device
+ for k in self.data:
+ if k not in filled_data:
+ filled_data[k] = torch.full_like(example_tensor, fill_value=np.nan)
+ return self.__class__(
+ data={
+ k: torch.cat([filled_data[k].to(state_data_device), v], dim=1)
+ for k, v in self.data.items()
+ },
+ time=xr.concat([initial_batch_data.time, self.time], dim="time"),
+ horizontal_dims=self.horizontal_dims,
+ )
+
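+
+# A minimal sketch (not used by the library): construct a small CPU batch and
+# round-trip its initial condition. Variable names and shapes here are arbitrary.
+def _example_batch_data_roundtrip() -> "BatchData":
+    n_samples, n_times, n_lat, n_lon = 2, 4, 8, 16
+    batch = BatchData.new_on_cpu(
+        data={"prog": torch.randn(n_samples, n_times, n_lat, n_lon)},
+        time=xr.DataArray(
+            np.arange(n_samples * n_times).reshape(n_samples, n_times),
+            dims=["sample", "time"],
+        ),
+    )
+    # Split off the first timestep as an initial condition, drop it from the
+    # window, then prepend it again; the result spans the original n_times.
+    start = batch.get_start(prognostic_names=["prog"], n_ic_timesteps=1)
+    window = batch.remove_initial_condition(n_ic_timesteps=1)
+    return window.prepend(start)
+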
+
+@dataclasses.dataclass
+class PairedData:
+ """A container for the data and time coordinate of a batch, with paired
+ prediction and target data.
+ """
+
+ prediction: TensorMapping
+ target: TensorMapping
+ time: xr.DataArray
+
+ @classmethod
+ def from_batch_data(
+ cls,
+ prediction: BatchData,
+ target: BatchData,
+ ) -> "PairedData":
+ if not np.all(prediction.time.values == target.time.values):
+ raise ValueError("Prediction and target time coordinate must be the same.")
+ return PairedData(prediction.data, target.data, prediction.time)
+
+ @classmethod
+ def new_on_device(
+ cls,
+ prediction: TensorMapping,
+ target: TensorMapping,
+ time: xr.DataArray,
+ ) -> "PairedData":
+ device = get_device()
+ _check_device(prediction, device)
+ _check_device(target, device)
+ return PairedData(prediction, target, time)
+
+ @classmethod
+ def new_on_cpu(
+ cls,
+ prediction: TensorMapping,
+ target: TensorMapping,
+ time: xr.DataArray,
+ ) -> "PairedData":
+ _check_device(prediction, torch.device("cpu"))
+ _check_device(target, torch.device("cpu"))
+ return PairedData(prediction, target, time)
+
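+
+# A minimal sketch (not used by the library): pair a prediction batch with a
+# target batch that shares its time coordinate. Variable names are arbitrary.
+def _example_paired_data() -> "PairedData":
+    time = xr.DataArray(np.arange(6).reshape(2, 3), dims=["sample", "time"])
+    prediction = BatchData.new_on_cpu(
+        data={"foo": torch.randn(2, 3, 8, 16)}, time=time
+    )
+    target = BatchData.new_on_cpu(data={"foo": torch.randn(2, 3, 8, 16)}, time=time)
+    return PairedData.from_batch_data(prediction=prediction, target=target)
+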
+
+T = TypeVar("T", covariant=True)
+
+
+U = TypeVar("U")
+
+
+class SizedMap(Generic[T, U], Sized, Iterable[U]):
+ def __init__(self, func: Callable[[T], U], iterable: DataLoader[T]):
+ self._func = func
+ self._iterable = iterable
+
+ def __len__(self) -> int:
+ return len(self._iterable)
+
+ def __iter__(self) -> Iterator[U]:
+ return map(self._func, self._iterable)
+
+
+def get_initial_condition(
+ loader: DataLoader[BatchData],
+ requirements: PrognosticStateDataRequirements,
+) -> PrognosticState:
+ for batch in loader:
+ return batch.to_device().get_start(
+ prognostic_names=requirements.names,
+ n_ic_timesteps=requirements.n_timesteps,
+ )
+ raise ValueError("No initial condition found in loader")
+
+
+class InferenceGriddedData(InferenceDataABC[PrognosticState, BatchData]):
+ """
+ Data as required for inference.
+
+ All data exposed from this class is on the current device.
+ """
+
+ def __init__(
+ self,
+ loader: DataLoader[BatchData],
+ initial_condition: Union[PrognosticState, PrognosticStateDataRequirements],
+ properties: DatasetProperties,
+ ):
+ """
+ Args:
+            loader: torch DataLoader, which returns batches of type BatchData,
+                whose data maps variable names to tensors of shape
+                [batch_size, time_window_size, <horizontal dims>].
+ Data can be on any device (but will typically be on CPU).
+ initial_condition: Initial condition for the inference, or a requirements
+ object specifying how to extract the initial condition from the first
+ batch of data. Data can be on any device.
+ properties: Batch-constant properties for the dataset, such as variable
+ metadata and coordinate information. Data can be on any device.
+
+ Note:
+ While input data can be on any device, all data exposed from this class
+ will be on the current device.
+ """
+ self._loader = loader
+ self._properties = properties.to_device()
+ self._n_initial_conditions: Optional[int] = None
+ if isinstance(initial_condition, PrognosticStateDataRequirements):
+ self._initial_condition: PrognosticState = get_initial_condition(
+ loader, initial_condition
+ )
+ else:
+ self._initial_condition = initial_condition.to_device()
+
+ @property
+ def loader(self) -> DataLoader[BatchData]:
+ def on_device(batch: BatchData) -> BatchData:
+ return batch.to_device()
+
+ return SizedMap(on_device, self._loader)
+
+ @property
+ def variable_metadata(self) -> Mapping[str, VariableMetadata]:
+ return self._properties.variable_metadata
+
+ @property
+ def vertical_coordinate(self) -> HybridSigmaPressureCoordinate:
+ return self._properties.vertical_coordinate
+
+ @property
+ def horizontal_coordinates(self) -> HorizontalCoordinates:
+ return self._properties.horizontal_coordinates
+
+ @property
+ def timestep(self) -> datetime.timedelta:
+ return self._properties.timestep
+
+ @property
+ def coords(self) -> Mapping[str, np.ndarray]:
+ return {
+ **self.horizontal_coordinates.coords,
+ **self.vertical_coordinate.coords,
+ }
+
+ @property
+ def gridded_operations(self) -> GriddedOperations:
+ return self.horizontal_coordinates.gridded_operations
+
+ @property
+ def _n_samples(self) -> int:
+ return len(self._loader.dataset) # type: ignore
+
+ @property
+ def _n_batches(self) -> int:
+ return len(self._loader) # type: ignore
+
+ @property
+ def _first_time(self) -> Any:
+ return self._loader.dataset[0][1].values[0] # type: ignore
+
+ @property
+ def _last_time(self) -> Any:
+ return self._loader.dataset[-1][1].values[0] # type: ignore
+
+ @property
+ def n_initial_conditions(self) -> int:
+ if self._n_initial_conditions is None:
+ example_data = next(iter(self.loader)).data
+ example_tensor = next(iter(example_data.values()))
+ self._n_initial_conditions = example_tensor.shape[0]
+ return self._n_initial_conditions
+
+ @property
+ def initial_condition(self) -> PrognosticState:
+ return self._initial_condition
+
+ def log_info(self, name: str):
+ logging.info(
+ f"{name} data: {self._n_samples} samples, " f"{self._n_batches} batches"
+ )
+ logging.info(f"{name} data: first sample's initial time: {self._first_time}")
+ logging.info(f"{name} data: last sample's initial time: {self._last_time}")
+
+
+class GriddedData(GriddedDataABC[BatchData]):
+ """
+ Data as required for pytorch training.
+
+ The data is assumed to be gridded, and attributes are included for
+ performing operations on gridded data.
+
+ All data exposed from this class is on the current device.
+ """
+
+ def __init__(
+ self,
+ loader: DataLoader[BatchData],
+ properties: DatasetProperties,
+ sampler: Optional[torch.utils.data.Sampler] = None,
+ ):
+ """
+ Args:
+            loader: torch DataLoader, which returns batches of type BatchData,
+                whose data maps variable names to tensors of shape
+                [batch_size, time_window_size, <horizontal dims>].
+ Data can be on any device (but will typically be on CPU).
+ properties: Batch-constant properties for the dataset, such as variable
+ metadata and coordinate information. Data can be on any device.
+ sampler: Optional sampler for the data loader. Provided to allow support for
+ distributed training.
+
+ Note:
+ While input data can be on any device, all data exposed from this class
+ will be on the current device.
+ """
+ self._loader = loader
+ self._properties = properties.to_device()
+ self._sampler = sampler
+ self._batch_size: Optional[int] = None
+
+ @property
+ def loader(self) -> DataLoader[BatchData]:
+ def on_device(batch: BatchData) -> BatchData:
+ return batch.to_device()
+
+ return SizedMap(on_device, self._loader)
+
+ @property
+ def variable_metadata(self) -> Mapping[str, VariableMetadata]:
+ return self._properties.variable_metadata
+
+ @property
+ def vertical_coordinate(self) -> HybridSigmaPressureCoordinate:
+ return self._properties.vertical_coordinate
+
+ @property
+ def horizontal_coordinates(self) -> HorizontalCoordinates:
+ return self._properties.horizontal_coordinates
+
+ @property
+ def timestep(self) -> datetime.timedelta:
+ return self._properties.timestep
+
+ @property
+ def coords(self) -> Mapping[str, np.ndarray]:
+ return {
+ **self.horizontal_coordinates.coords,
+ **self.vertical_coordinate.coords,
+ }
+
+ @property
+ def grid(self) -> Literal["equiangular", "legendre-gauss", "healpix"]:
+ return self.horizontal_coordinates.grid
+
+ @property
+ def gridded_operations(self) -> GriddedOperations:
+ return self.horizontal_coordinates.gridded_operations
+
+ @property
+ def n_samples(self) -> int:
+ return len(self._loader.dataset) # type: ignore
+
+ @property
+ def n_batches(self) -> int:
+ return len(self._loader) # type: ignore
+
+ @property
+ def _first_time(self) -> Any:
+ return self._loader.dataset[0][1].values[0] # type: ignore
+
+ @property
+ def _last_time(self) -> Any:
+ return self._loader.dataset[-1][1].values[0] # type: ignore
+
+ @property
+ def batch_size(self) -> int:
+ if self._batch_size is None:
+ example_data = next(iter(self.loader)).data
+ example_tensor = next(iter(example_data.values()))
+ self._batch_size = example_tensor.shape[0]
+ return self._batch_size
+
+ def log_info(self, name: str):
+ logging.info(
+ f"{name} data: {self.n_samples} samples, " f"{self.n_batches} batches"
+ )
+ logging.info(f"{name} data: first sample's initial time: {self._first_time}")
+ logging.info(f"{name} data: last sample's initial time: {self._last_time}")
+
+ def set_epoch(self, epoch: int):
+ """
+ Set the epoch for the data loader sampler, if it is a distributed sampler.
+ """
+ if self._sampler is not None and isinstance(
+ self._sampler, torch.utils.data.DistributedSampler
+ ):
+ self._sampler.set_epoch(epoch)
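+
+
+# A minimal sketch (not used by the library): a typical epoch loop over
+# GriddedData, where `train_one_batch` stands in for a user-supplied training step.
+def _example_epoch_loop(
+    data: "GriddedData",
+    train_one_batch: Callable[[TensorMapping, xr.DataArray], None],
+    n_epochs: int,
+) -> None:
+    for epoch in range(n_epochs):
+        # Reshuffles data across ranks when a DistributedSampler is in use.
+        data.set_epoch(epoch)
+        for batch in data.loader:
+            # Each batch is a BatchData already moved to the current device.
+            train_one_batch(batch.data, batch.time)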
diff --git a/fme/fme/ace/data_loading/config.py b/fme/fme/ace/data_loading/config.py
new file mode 100644
index 0000000..292d356
--- /dev/null
+++ b/fme/fme/ace/data_loading/config.py
@@ -0,0 +1,34 @@
+import dataclasses
+from typing import Optional, Sequence
+
+from fme.core.dataset.config import XarrayDataConfig
+from fme.core.distributed import Distributed
+
+
+@dataclasses.dataclass
+class DataLoaderConfig:
+ """
+ Parameters:
+ dataset: A sequence of configurations each defining a dataset
+ to be loaded. This sequence of datasets will be concatenated.
+ batch_size: Number of samples per batch.
+ num_data_workers: Number of parallel workers to use for data loading.
+        prefetch_factor: How many batches a single data worker will attempt to
+            hold in host memory at a given time.
+        strict_ensemble: Whether to enforce that the datasets to be concatenated
+            have the same dimensions and coordinates.
+ """
+
+ dataset: Sequence[XarrayDataConfig]
+ batch_size: int
+ num_data_workers: int = 0
+ prefetch_factor: Optional[int] = None
+ strict_ensemble: bool = True
+
+ def __post_init__(self):
+ dist = Distributed.get_instance()
+ if self.batch_size % dist.world_size != 0:
+ raise ValueError(
+ "batch_size must be divisible by the number of parallel "
+ f"workers, got {self.batch_size} and {dist.world_size}"
+ )
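+
+
+# A minimal sketch (hypothetical paths): a config that concatenates two ensemble
+# members; batch_size is validated against the distributed world size above.
+def _example_data_loader_config() -> DataLoaderConfig:
+    return DataLoaderConfig(
+        dataset=[
+            XarrayDataConfig(data_path="/data/member_01"),
+            XarrayDataConfig(data_path="/data/member_02"),
+        ],
+        batch_size=4,
+        num_data_workers=2,
+    )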
diff --git a/fme/fme/ace/data_loading/getters.py b/fme/fme/ace/data_loading/getters.py
new file mode 100644
index 0000000..557296f
--- /dev/null
+++ b/fme/fme/ace/data_loading/getters.py
@@ -0,0 +1,206 @@
+import logging
+from typing import Optional, Union
+
+import torch.utils.data
+from torch.utils.data.distributed import DistributedSampler
+from torch.utils.data.sampler import RandomSampler
+
+from fme.ace.data_loading.batch_data import BatchData
+from fme.ace.requirements import PrognosticStateDataRequirements
+from fme.core.dataset.getters import get_dataset
+from fme.core.dataset.requirements import DataRequirements
+from fme.core.dataset.xarray import XarrayDataset
+from fme.core.device import using_gpu
+from fme.core.distributed import Distributed
+
+from .batch_data import GriddedData, InferenceGriddedData, PrognosticState
+from .config import DataLoaderConfig
+from .inference import (
+ ExplicitIndices,
+ ForcingDataLoaderConfig,
+ InferenceDataLoaderConfig,
+ InferenceDataset,
+)
+
+logger = logging.getLogger(__name__)
+
+
+def get_data_loader(
+ config: DataLoaderConfig,
+ train: bool,
+ requirements: DataRequirements,
+) -> GriddedData:
+ """
+ Args:
+ config: Parameters for the data loader.
+ train: Whether loader is intended for training or validation data; if True,
+ then data will be shuffled.
+ requirements: Data requirements for the model.
+ """
+ dataset, properties = get_dataset(
+ config.dataset, requirements, strict=config.strict_ensemble
+ )
+ dist = Distributed.get_instance()
+
+ if dist.is_distributed():
+ sampler = DistributedSampler(dataset, shuffle=train)
+ else:
+ sampler = RandomSampler(dataset) if train else None
+
+ if properties.is_remote:
+ # GCSFS and S3FS are not fork-safe, so we need to use forkserver
+ mp_context = "forkserver"
+ persistent_workers = True
+ else:
+ mp_context = None
+ persistent_workers = False
+
+ def collate_fn(samples):
+ return BatchData.from_sample_tuples(
+ samples,
+ horizontal_dims=list(properties.horizontal_coordinates.dims),
+ )
+
+ batch_size = dist.local_batch_size(int(config.batch_size))
+
+ if config.prefetch_factor is None:
+        # The DataLoader's own default for prefetch_factor is not None, so we
+        # leave the argument unset rather than passing None.
+ kwargs = {}
+ else:
+ kwargs = {"prefetch_factor": config.prefetch_factor}
+
+ dataloader = torch.utils.data.DataLoader(
+ dataset,
+ batch_size=batch_size,
+ num_workers=config.num_data_workers,
+ sampler=sampler,
+ drop_last=True,
+ pin_memory=using_gpu(),
+ collate_fn=collate_fn,
+ multiprocessing_context=mp_context,
+ persistent_workers=persistent_workers,
+ **kwargs,
+ )
+
+ if len(dataloader) == 0:
+ raise ValueError(
+ "No batches in dataloader: "
+ f"{len(dataloader.dataset)} samples, {len(dataloader)} batches. "
+ f"Batch size is {dataloader.batch_size}"
+ )
+
+ return GriddedData(
+ loader=dataloader,
+ properties=properties,
+ sampler=sampler,
+ )
+
+
+def get_inference_data(
+ config: InferenceDataLoaderConfig,
+ total_forward_steps: int,
+ window_requirements: DataRequirements,
+ initial_condition: Union[PrognosticState, PrognosticStateDataRequirements],
+ surface_temperature_name: Optional[str] = None,
+ ocean_fraction_name: Optional[str] = None,
+) -> InferenceGriddedData:
+ """
+ Args:
+ config: Parameters for the data loader.
+ total_forward_steps: Total number of forward steps to take over the course of
+ inference.
+ window_requirements: Data requirements for the model.
+        initial_condition: Initial condition for the inference, or a requirements
+            object specifying how to extract the initial condition from the first
+            batch of data.
+ surface_temperature_name: Name of the surface temperature variable. Can be
+ set to None if no ocean temperature prescribing is being used.
+ ocean_fraction_name: Name of the ocean fraction variable. Can be set to None
+ if no ocean temperature prescribing is being used.
+
+ Returns:
+ A data loader for inference with coordinates and metadata.
+ """
+ dataset = InferenceDataset(
+ config,
+ total_forward_steps,
+ window_requirements,
+ surface_temperature_name,
+ ocean_fraction_name,
+ )
+ properties = dataset.properties
+
+ if properties.is_remote:
+ # GCSFS and S3FS are not fork-safe, so we need to use forkserver
+ # persist workers since startup is slow
+ mp_context = "forkserver"
+ persistent_workers = True
+ else:
+ mp_context = None
+ persistent_workers = False
+
+ logging.info(f"Multiprocessing inference context: {mp_context or 'fork'}")
+
+ # we roll our own batching in InferenceDataset, which is why batch_size=None below
+ loader = torch.utils.data.DataLoader(
+ dataset,
+ batch_size=None,
+ num_workers=config.num_data_workers,
+ shuffle=False,
+ pin_memory=using_gpu(),
+ multiprocessing_context=mp_context,
+ persistent_workers=persistent_workers,
+ )
+ gridded_data = InferenceGriddedData(
+ loader=loader,
+ initial_condition=initial_condition,
+ properties=properties,
+ )
+
+ return gridded_data
+
+
+def get_forcing_data(
+ config: ForcingDataLoaderConfig,
+ total_forward_steps: int,
+ window_requirements: DataRequirements,
+ initial_condition: PrognosticState,
+ surface_temperature_name: Optional[str] = None,
+ ocean_fraction_name: Optional[str] = None,
+) -> InferenceGriddedData:
+ """Return a GriddedData loader for forcing data based on the initial condition.
+ This function determines the start indices for the forcing data based on the initial
+ time in the provided initial condition.
+
+ Args:
+ config: Parameters for the forcing data loader.
+ total_forward_steps: Total number of forward steps to take over the course of
+ inference.
+ window_requirements: Data requirements for the forcing data.
+ initial_condition: Initial condition for the inference.
+ surface_temperature_name: Name of the surface temperature variable. Can be
+ set to None if no ocean temperature prescribing is being used.
+ ocean_fraction_name: Name of the ocean fraction variable. Can be set to None
+ if no ocean temperature prescribing is being used.
+
+ Returns:
+ A data loader for forcing data with coordinates and metadata.
+ """
+ initial_time = initial_condition.as_batch_data().time
+ if initial_time.shape[1] != 1:
+ raise NotImplementedError("code assumes initial time only has 1 timestep")
+ available_times = XarrayDataset(config.dataset, window_requirements).all_times
+ start_time_indices = []
+ for time in initial_time.values[:, 0]:
+ start_time_indices.append(available_times.get_loc(time))
+ inference_config = config.build_inference_config(
+ start_indices=ExplicitIndices(start_time_indices)
+ )
+ return get_inference_data(
+ config=inference_config,
+ total_forward_steps=total_forward_steps,
+ window_requirements=window_requirements,
+ initial_condition=initial_condition,
+ surface_temperature_name=surface_temperature_name,
+ ocean_fraction_name=ocean_fraction_name,
+ )
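+
+
+# A minimal sketch (hypothetical path and variable names): wiring a
+# DataLoaderConfig and DataRequirements into get_data_loader for training data.
+def _example_get_data_loader() -> GriddedData:
+    # XarrayDataConfig is not imported at module level here, so import it locally.
+    from fme.core.dataset.config import XarrayDataConfig
+
+    config = DataLoaderConfig(
+        dataset=[XarrayDataConfig(data_path="/data/training")],
+        batch_size=4,
+    )
+    # One initial condition plus one forward step per training window.
+    requirements = DataRequirements(names=["foo", "bar"], n_timesteps=2)
+    return get_data_loader(config, train=True, requirements=requirements)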
diff --git a/fme/fme/ace/data_loading/inference.py b/fme/fme/ace/data_loading/inference.py
new file mode 100644
index 0000000..312d286
--- /dev/null
+++ b/fme/fme/ace/data_loading/inference.py
@@ -0,0 +1,295 @@
+import dataclasses
+import logging
+from math import ceil
+from typing import Optional, Sequence, Union
+
+import cftime
+import numpy as np
+import torch
+import xarray as xr
+
+from fme.ace.data_loading.batch_data import BatchData
+from fme.ace.data_loading.perturbation import SSTPerturbation
+from fme.core.coordinates import LatLonCoordinates
+from fme.core.dataset.config import XarrayDataConfig
+from fme.core.dataset.requirements import DataRequirements
+from fme.core.dataset.xarray import DatasetProperties, XarrayDataset
+from fme.core.distributed import Distributed
+from fme.core.typing_ import Slice
+
+
+@dataclasses.dataclass
+class TimestampList:
+ """
+ Configuration for a list of timestamps.
+
+ Parameters:
+ times: List of timestamps.
+ timestamp_format: Format of the timestamps.
+ """
+
+ times: Sequence[str]
+ timestamp_format: str = "%Y-%m-%dT%H:%M:%S"
+
+ def as_indices(self, time_index: xr.CFTimeIndex) -> np.ndarray:
+ datetimes = [
+ cftime.datetime.strptime(
+ t, self.timestamp_format, calendar=time_index.calendar
+ )
+ for t in self.times
+ ]
+ (indices,) = time_index.isin(datetimes).nonzero()
+ if len(indices) != len(self.times):
+ missing_times = set(datetimes) - set(time_index[indices])
+ raise ValueError(
+ f"Inference initial condition timestamps {missing_times} "
+ "were not found in the dataset."
+ )
+ return indices
+
+ @property
+ def n_initial_conditions(self) -> int:
+ return len(self.times)
+
+
+@dataclasses.dataclass
+class InferenceInitialConditionIndices:
+ """
+ Configuration of the indices for initial conditions during inference.
+
+ Parameters:
+ n_initial_conditions: Number of initial conditions to use.
+ first: Index of the first initial condition.
+ interval: Interval between initial conditions.
+ """
+
+ n_initial_conditions: int
+ first: int = 0
+ interval: int = 1
+
+ def __post_init__(self):
+        if self.interval <= 0:
+ raise ValueError("interval must be positive")
+
+ def as_indices(self) -> np.ndarray:
+ stop = self.n_initial_conditions * self.interval + self.first
+ return np.arange(self.first, stop, self.interval)
+
+
+@dataclasses.dataclass
+class ExplicitIndices:
+ """
+ Configure indices providing them explicitly.
+
+ Parameters:
+ list: List of integer indices.
+ """
+
+ list: Sequence[int]
+
+ def as_indices(self) -> np.ndarray:
+ return np.array(self.list)
+
+ @property
+ def n_initial_conditions(self) -> int:
+ return len(self.list)
+
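+
+# A minimal sketch (not used by the library): both integer-based start-index
+# configurations above reduce to the same positions in the dataset's time axis.
+def _example_start_indices() -> None:
+    evenly_spaced = InferenceInitialConditionIndices(
+        n_initial_conditions=3, first=2, interval=4
+    )
+    explicit = ExplicitIndices([2, 6, 10])
+    # Both evaluate to array([2, 6, 10]).
+    np.testing.assert_array_equal(evenly_spaced.as_indices(), explicit.as_indices())
+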
+
+@dataclasses.dataclass
+class InferenceDataLoaderConfig:
+ """
+ Configuration for inference data.
+
+ This is like the `DataLoaderConfig` class, but with some additional
+ constraints. During inference, we have only one batch, so the number of
+ samples directly determines the size of that batch.
+
+ Parameters:
+ dataset: Configuration to define the dataset.
+ start_indices: Configuration of the indices for initial conditions
+ during inference. This can be a list of timestamps, a list of
+ integer indices, or a slice configuration of the integer indices.
+ Values following the initial condition will still come from
+ the full dataset.
+ num_data_workers: Number of parallel workers to use for data loading.
+ perturbations: Configuration for SST perturbations.
+ persistence_names: Names of variables for which all returned values
+ will be the same as the initial condition. When evaluating initial
+ condition predictability, set this to forcing variables that should
+ not be updated during inference (e.g. surface temperature).
+ """
+
+ dataset: XarrayDataConfig
+ start_indices: Union[
+ InferenceInitialConditionIndices, ExplicitIndices, TimestampList
+ ]
+ num_data_workers: int = 0
+ perturbations: Optional[SSTPerturbation] = None
+ persistence_names: Optional[Sequence[str]] = None
+
+ def __post_init__(self):
+ if self.dataset.subset != Slice(None, None, None):
+ raise ValueError("Inference data may not be subset.")
+
+ @property
+ def n_initial_conditions(self) -> int:
+ return self.start_indices.n_initial_conditions
+
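+
+# A minimal sketch (hypothetical path and timestamps): an inference configuration
+# that starts two forecasts at explicit timestamps from the dataset's time axis.
+def _example_inference_loader_config() -> "InferenceDataLoaderConfig":
+    return InferenceDataLoaderConfig(
+        dataset=XarrayDataConfig(data_path="/data/forcing"),
+        start_indices=TimestampList(
+            times=["2021-01-01T00:00:00", "2021-07-01T00:00:00"]
+        ),
+    )
+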
+
+@dataclasses.dataclass
+class ForcingDataLoaderConfig:
+ """
+ Configuration for the forcing data.
+
+ Parameters:
+ dataset: Configuration to define the dataset.
+ num_data_workers: Number of parallel workers to use for data loading.
+        perturbations: Configuration for SST perturbations used in forcing data.
+ persistence_names: Names of variables for which all returned values
+ will be the same as the initial condition. When evaluating initial
+ condition predictability, set this to forcing variables that should
+ not be updated during inference (e.g. surface temperature).
+ """
+
+ dataset: XarrayDataConfig
+ num_data_workers: int = 0
+ perturbations: Optional[SSTPerturbation] = None
+ persistence_names: Optional[Sequence[str]] = None
+
+ def __post_init__(self):
+ if self.dataset.subset != Slice(None, None, None):
+ raise ValueError("Inference data may not be subset.")
+
+ def build_inference_config(self, start_indices: ExplicitIndices):
+ return InferenceDataLoaderConfig(
+ dataset=self.dataset,
+ num_data_workers=self.num_data_workers,
+ start_indices=start_indices,
+ perturbations=self.perturbations,
+ persistence_names=self.persistence_names,
+ )
+
+
+class InferenceDataset(torch.utils.data.Dataset):
+ def __init__(
+ self,
+ config: InferenceDataLoaderConfig,
+ total_forward_steps: int,
+ requirements: DataRequirements,
+ surface_temperature_name: Optional[str] = None,
+ ocean_fraction_name: Optional[str] = None,
+ ):
+ dataset = XarrayDataset(config.dataset, requirements=requirements)
+ self._dataset = dataset
+ self._properties = dataset.properties
+ self._forward_steps_in_memory = requirements.n_timesteps - 1
+ self._total_forward_steps = total_forward_steps
+ self._perturbations = config.perturbations
+ self._surface_temperature_name = surface_temperature_name
+ self._ocean_fraction_name = ocean_fraction_name
+ self._n_initial_conditions = config.n_initial_conditions
+
+ if isinstance(config.start_indices, TimestampList):
+ self._start_indices = config.start_indices.as_indices(dataset.all_times)
+ else:
+ self._start_indices = config.start_indices.as_indices()
+ self._validate_n_forward_steps()
+ if isinstance(self._properties.horizontal_coordinates, LatLonCoordinates):
+ self._lats, self._lons = self._properties.horizontal_coordinates.meshgrid
+ else:
+ if self._perturbations is not None:
+ raise ValueError(
+ "Currently, SST perturbations are only supported \
+ for lat/lon coordinates."
+ )
+ if self._perturbations is not None and (
+ self._surface_temperature_name is None or self._ocean_fraction_name is None
+ ):
+ raise ValueError(
+ "No ocean configuration found, \
+ SST perturbations require an ocean configuration."
+ )
+
+ self._persistence_data: Optional[BatchData] = None
+ if config.persistence_names is not None:
+ first_sample = self._get_batch_data(0)
+ self._persistence_data = first_sample.subset_names(
+ config.persistence_names
+ ).select_time_slice(slice(0, 1))
+
+ def _get_batch_data(self, index) -> BatchData:
+ dist = Distributed.get_instance()
+ i_start = index * self._forward_steps_in_memory
+ sample_tuples = []
+ for i_member in range(self._n_initial_conditions):
+ # check if sample is one this local rank should process
+ if i_member % dist.world_size != dist.rank:
+ continue
+ i_window_start = i_start + self._start_indices[i_member]
+ i_window_end = i_window_start + self._forward_steps_in_memory + 1
+ if i_window_end > (
+ self._total_forward_steps + self._start_indices[i_member]
+ ):
+ i_window_end = (
+ self._total_forward_steps + self._start_indices[i_member] + 1
+ )
+ window_time_slice = slice(i_window_start, i_window_end)
+ tensors, time = self._dataset.get_sample_by_time_slice(window_time_slice)
+ if self._perturbations is not None:
+ if (
+ self._surface_temperature_name is None
+ or self._ocean_fraction_name is None
+ ):
+ raise ValueError(
+ "Surface temperature and ocean fraction names must be provided \
+ to apply SST perturbations."
+ )
+ logging.debug("Applying SST perturbations to forcing data")
+ for perturbation in self._perturbations.perturbations:
+ perturbation.apply_perturbation(
+ tensors[self._surface_temperature_name],
+ self._lats,
+ self._lons,
+ tensors[self._ocean_fraction_name],
+ )
+ sample_tuples.append((tensors, time))
+ return BatchData.from_sample_tuples(
+ sample_tuples,
+ horizontal_dims=list(self.properties.horizontal_coordinates.dims),
+ )
+
+ def __getitem__(self, index) -> BatchData:
+ dist = Distributed.get_instance()
+ result = self._get_batch_data(index)
+ if self._persistence_data is not None:
+ updated_data = {}
+ for key, value in self._persistence_data.data.items():
+ updated_data[key] = value.expand_as(result.data[key])
+ result.data = {**result.data, **updated_data}
+ assert result.time.shape[0] == self._n_initial_conditions // dist.world_size
+ return result
+
+ def __len__(self) -> int:
+        # The ceil rounds the ratio up so that a final batch covering fewer
+        # forward steps than the rest is still included in the loading.
+ return int(ceil(self._total_forward_steps / self._forward_steps_in_memory))
+
+ @property
+ def properties(self) -> DatasetProperties:
+ return self._properties
+
+ @property
+ def n_forward_steps(self) -> int:
+ return self._total_forward_steps
+
+ def _validate_n_forward_steps(self):
+ max_steps = self._dataset.total_timesteps - max(self._start_indices) - 1
+ if self._total_forward_steps > max_steps:
+ raise ValueError(
+ f"The number of forward inference steps ({self._total_forward_steps}) "
+ "must be less than or equal to the number of possible steps "
+ f"({max_steps}) in dataset after the last initial condition's "
+ "start index."
+ )
diff --git a/fme/fme/ace/data_loading/perturbation.py b/fme/fme/ace/data_loading/perturbation.py
new file mode 100644
index 0000000..d21021d
--- /dev/null
+++ b/fme/fme/ace/data_loading/perturbation.py
@@ -0,0 +1,194 @@
+import abc
+import dataclasses
+from typing import Any, Callable, ClassVar, Mapping, Tuple, Type, TypeVar
+
+import dacite
+import numpy as np
+import torch
+
+from fme.core.registry.registry import Registry
+
+
+@dataclasses.dataclass
+class PerturbationConfig(abc.ABC):
+ """
+    Base class for perturbation configurations.
+ """
+
+ @classmethod
+ def from_state(cls, state: Mapping[str, Any]) -> "PerturbationConfig":
+ """
+        Create a PerturbationConfig from a dictionary containing all the
+        information needed to build it.
+ """
+ return dacite.from_dict(
+ data_class=cls, data=state, config=dacite.Config(strict=True)
+ )
+
+ @abc.abstractmethod
+ def apply_perturbation(
+ self,
+ data: torch.Tensor,
+ lat: torch.Tensor,
+ lon: torch.Tensor,
+ ocean_fraction: torch.Tensor,
+ ) -> None: ...
+
+
+PT = TypeVar("PT", bound=Type[PerturbationConfig])
+
+
+@dataclasses.dataclass
+class PerturbationSelector:
+ type: str
+ config: Mapping[str, Any]
+ registry: ClassVar[Registry] = Registry()
+
+    def __post_init__(self):
+        if self.registry is not PerturbationSelector.registry:
+            raise ValueError(
+                "PerturbationSelector.registry should not be set manually"
+            )
+
+ @classmethod
+ def register(cls, type_name) -> Callable[[PT], PT]:
+ return cls.registry.register(type_name)
+
+ def build(self) -> PerturbationConfig:
+ return self.registry.from_dict(self.get_state())
+
+ def get_state(self) -> Mapping[str, Any]:
+ """
+ Get a dictionary containing all the information needed to build a
+ PerturbationConfig.
+
+ """
+ return {"type": self.type, "config": self.config}
+
+ @classmethod
+ def get_available_types(cls):
+ """This class method is used to expose all available types of Perturbations."""
+ return cls(type="", config={}).registry._types.keys()
+
+
+@dataclasses.dataclass
+class SSTPerturbation:
+ """
+    Configuration for sea surface temperature perturbations.
+    Currently, perturbations are always applied to both the initial condition
+    and the forcing data.
+
+ Parameters:
+ sst: List of perturbation selectors for SST perturbations.
+ """
+
+ sst: list[PerturbationSelector]
+
+ def __post_init__(self):
+ self.perturbations: list[PerturbationConfig] = [
+ perturbation.build() for perturbation in self.sst
+ ]
+
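+
+# A minimal sketch (not used by the library): the SST perturbation passed to the
+# inference data loader configs is built from registry selectors; the "constant"
+# type is registered further down in this module.
+def _example_sst_perturbation() -> "SSTPerturbation":
+    return SSTPerturbation(
+        sst=[PerturbationSelector(type="constant", config={"amplitude": 2.0})]
+    )
+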
+
+def _get_ocean_mask(ocean_fraction: torch.Tensor, cutoff: float = 0.5) -> torch.Tensor:
+ return ocean_fraction > cutoff
+
+
+@PerturbationSelector.register("constant")
+@dataclasses.dataclass
+class ConstantConfig(PerturbationConfig):
+ """
+ Configuration for a constant perturbation.
+ """
+
+ amplitude: float = 1.0
+
+ def apply_perturbation(
+ self,
+ data: torch.Tensor,
+ lat: torch.Tensor,
+ lon: torch.Tensor,
+ ocean_fraction: torch.Tensor,
+ ):
+ ocean_mask = _get_ocean_mask(ocean_fraction)
+ data[ocean_mask] += self.amplitude # type: ignore
+
+
+@PerturbationSelector.register("greens_function")
+@dataclasses.dataclass
+class GreensFunctionConfig(PerturbationConfig):
+ """
+ Configuration for a single sinusoidal patch of a Green's function perturbation.
+    See equation 1 in Bloch-Johnson, J., et al. (2024).
+
+ Parameters:
+ amplitude: The amplitude of the perturbation,
+ maximum is reached at (lat_center, lon_center).
+ lat_center: The latitude at the center of the patch in degrees.
+ lon_center: The longitude at the center of the patch in degrees.
+ lat_width: latitudinal width of the patch in degrees.
+ lon_width: longitudinal width of the patch in degrees.
+ """
+
+ amplitude: float = 1.0
+ lat_center: float = 0.0
+ lon_center: float = 0.0
+ lat_width: float = 10.0
+ lon_width: float = 10.0
+
+ def __post_init__(self):
+ self._lat_center_rad = np.deg2rad(self.lat_center)
+ self._lon_center_rad = np.deg2rad(self.lon_center)
+ self._lat_width_rad = np.deg2rad(self.lat_width)
+ self._lon_width_rad = np.deg2rad(self.lon_width)
+
+ def _wrap_longitude_discontinuity(
+ self,
+ lon: torch.Tensor,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Assume longitude is in the range [0, 360) degrees.
+ If the patch crosses the discontinuity at 0/360 degrees,
+ shift the longitude accordingly.
+ """
+ lon_min = self.lon_center - self.lon_width / 2.0
+ lon_max = self.lon_center + self.lon_width / 2.0
+ if lon_min < 0:
+ lon_shifted = ((lon + 180) % 360) - 180
+ lon_in_patch = (lon_shifted > lon_min) & (lon_shifted < lon_max)
+ elif lon_max > 360:
+ lon_in_patch = (lon > lon_min) | (lon < lon_max % 360)
+ lon_shifted = ((lon + 180) % 360) + 180
+ else:
+ lon_in_patch = (lon > lon_min) & (lon < lon_max)
+ lon_shifted = lon
+ return lon_in_patch, lon_shifted
+
+ def apply_perturbation(
+ self,
+ data: torch.Tensor,
+ lat: torch.Tensor,
+ lon: torch.Tensor,
+ ocean_fraction: torch.Tensor,
+ ):
+ lat_in_patch = torch.abs(lat - self.lat_center) < self.lat_width / 2.0
+ lon_in_patch, lon_shifted = self._wrap_longitude_discontinuity(lon)
+ mask = lat_in_patch & lon_in_patch
+ ocean_mask = _get_ocean_mask(ocean_fraction)
+ perturbation = self.amplitude * (
+ torch.cos(
+ torch.pi
+ / 2
+ * (lat.deg2rad() - self._lat_center_rad)
+ / (self._lat_width_rad / 2.0)
+ )
+ ** 2
+ * torch.cos(
+ torch.pi
+ / 2
+ * (lon_shifted.deg2rad() - self._lon_center_rad)
+ / (self._lon_width_rad / 2.0)
+ )
+ ** 2
+ )
+ mask = mask.expand(data.shape)
+ perturbation = perturbation.expand(data.shape)
+ data[mask & ocean_mask] += perturbation[mask & ocean_mask]
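+
+
+# A minimal sketch (not used by the library): applying the constant perturbation
+# in place on a small lat/lon grid where every point counts as ocean
+# (ocean_fraction > 0.5).
+def _example_apply_constant_perturbation() -> torch.Tensor:
+    lat = torch.linspace(-90.0, 90.0, 8).reshape(8, 1).expand(8, 16)
+    lon = torch.linspace(0.0, 337.5, 16).reshape(1, 16).expand(8, 16)
+    data = torch.zeros(8, 16)
+    ocean_fraction = torch.ones(8, 16)
+    ConstantConfig(amplitude=2.0).apply_perturbation(data, lat, lon, ocean_fraction)
+    # data is now 2.0 everywhere, since the ocean mask covers the whole grid.
+    return data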
diff --git a/fme/fme/ace/data_loading/test_batch_data.py b/fme/fme/ace/data_loading/test_batch_data.py
new file mode 100644
index 0000000..d4a2643
--- /dev/null
+++ b/fme/fme/ace/data_loading/test_batch_data.py
@@ -0,0 +1,147 @@
+from typing import List
+
+import numpy as np
+import pytest
+import torch
+import xarray as xr
+
+from fme.ace.data_loading.batch_data import BatchData
+from fme.core.device import get_device
+
+
+def assert_metadata_equal(a: BatchData, b: BatchData):
+ assert a.horizontal_dims == b.horizontal_dims
+
+
+def get_batch_data(
+ names: List[str],
+ n_samples: int,
+ n_times: int,
+ horizontal_dims: List[str],
+ n_lat: int = 8,
+ n_lon: int = 16,
+):
+ device = get_device()
+ return BatchData(
+ data={
+ name: torch.randn(n_samples, n_times, n_lat, n_lon, device=device)
+ for name in names
+ },
+ time=xr.DataArray(np.random.rand(n_samples, n_times), dims=["sample", "time"]),
+ horizontal_dims=horizontal_dims,
+ )
+
+
+@pytest.mark.parametrize(
+ "names, prognostic_names",
+ [
+ pytest.param(["prog"], ["prog"], id="all prognostic"),
+ pytest.param(["prog", "forcing"], ["prog"], id="some prognostic"),
+ pytest.param(["forcing1", "forcing2"], [], id="no prognostic"),
+ ],
+)
+@pytest.mark.parametrize("n_ic_timesteps", [1, 2])
+def test_get_start(names: List[str], prognostic_names: List[str], n_ic_timesteps: int):
+ n_samples = 2
+ n_times = 5
+ n_lat = 8
+ n_lon = 16
+ horizontal_dims = ["lat", "lon"]
+ batch_data = get_batch_data(
+ names=names,
+ n_samples=n_samples,
+ n_times=n_times,
+ horizontal_dims=horizontal_dims,
+ n_lat=n_lat,
+ n_lon=n_lon,
+ )
+ start = batch_data.get_start(prognostic_names, n_ic_timesteps).as_batch_data()
+ assert_metadata_equal(start, batch_data)
+ assert start.time.equals(batch_data.time.isel(time=slice(0, n_ic_timesteps)))
+ assert set(start.data.keys()) == set(prognostic_names)
+ for name in prognostic_names:
+ assert start.data[name].shape == (n_samples, n_ic_timesteps, n_lat, n_lon)
+ np.testing.assert_allclose(
+ start.data[name].cpu().numpy(),
+ batch_data.data[name][:, :n_ic_timesteps, ...].cpu().numpy(),
+ )
+
+
+@pytest.mark.parametrize(
+ "names, prognostic_names",
+ [
+ pytest.param(["prog"], ["prog"], id="all prognostic"),
+ pytest.param(["prog", "forcing"], ["prog"], id="some prognostic"),
+ pytest.param(["forcing1", "forcing2"], [], id="no prognostic"),
+ ],
+)
+@pytest.mark.parametrize("n_ic_timesteps", [1, 2])
+def test_get_end(names: List[str], prognostic_names: List[str], n_ic_timesteps: int):
+ n_samples = 2
+ n_times = 5
+ n_lat = 8
+ n_lon = 16
+ horizontal_dims = ["lat", "lon"]
+ batch_data = get_batch_data(
+ names=names,
+ n_samples=n_samples,
+ n_times=n_times,
+ horizontal_dims=horizontal_dims,
+ n_lat=n_lat,
+ n_lon=n_lon,
+ )
+ end = batch_data.get_end(prognostic_names, n_ic_timesteps).as_batch_data()
+ assert_metadata_equal(end, batch_data)
+ assert end.time.equals(batch_data.time.isel(time=slice(-n_ic_timesteps, None)))
+ assert set(end.data.keys()) == set(prognostic_names)
+ for name in prognostic_names:
+ assert end.data[name].shape == (n_samples, n_ic_timesteps, n_lat, n_lon)
+ np.testing.assert_allclose(
+ end.data[name].cpu().numpy(),
+ batch_data.data[name][:, -n_ic_timesteps:, ...].cpu().numpy(),
+ )
+
+
+@pytest.mark.parametrize(
+ "names, prepend_names",
+ [
+ pytest.param(["prog"], ["prog"], id="all prepended"),
+ pytest.param(["prog", "forcing"], ["prog"], id="some prepended"),
+ ],
+)
+@pytest.mark.parametrize("n_ic_timesteps", [1, 2])
+def test_prepend(names: List[str], prepend_names: List[str], n_ic_timesteps: int):
+ n_samples = 2
+ n_times = 5
+ n_lat = 8
+ n_lon = 16
+ horizontal_dims = ["lat", "lon"]
+ batch_data = get_batch_data(
+ names=names,
+ n_samples=n_samples,
+ n_times=n_times,
+ horizontal_dims=horizontal_dims,
+ n_lat=n_lat,
+ n_lon=n_lon,
+ )
+ start_data = batch_data.get_start(prepend_names, n_ic_timesteps)
+ prepended = batch_data.prepend(start_data)
+ start_batch_data = start_data.as_batch_data()
+ assert_metadata_equal(prepended, batch_data)
+ assert prepended.time.isel(time=slice(n_ic_timesteps, None)).equals(batch_data.time)
+ assert set(prepended.data.keys()) == set(names)
+ for name in names:
+ np.testing.assert_allclose(
+ prepended.data[name][:, n_ic_timesteps:, ...].cpu().numpy(),
+ batch_data.data[name].cpu().numpy(),
+ )
+ for name in prepend_names:
+ np.testing.assert_allclose(
+ prepended.data[name][:, :n_ic_timesteps, ...].cpu().numpy(),
+ start_batch_data.data[name].cpu().numpy(),
+ )
+ assert prepended.time.shape == (n_samples, n_times + n_ic_timesteps)
+ for name in set(names) - set(prepend_names):
+ assert np.all(
+ np.isnan(prepended.data[name][:, :n_ic_timesteps, ...].cpu().numpy())
+ )
diff --git a/fme/fme/core/data_loading/test_data_loader.py b/fme/fme/ace/data_loading/test_data_loader.py
similarity index 59%
rename from fme/fme/core/data_loading/test_data_loader.py
rename to fme/fme/ace/data_loading/test_data_loader.py
index 81ce1ed..3d261df 100644
--- a/fme/fme/core/data_loading/test_data_loader.py
+++ b/fme/fme/ace/data_loading/test_data_loader.py
@@ -1,8 +1,8 @@
"""This file contains unit tests related to creating torch Datasets from climate
data (e.g. netCDF files)."""
-import datetime
import math
+import os
import pathlib
from typing import List
@@ -12,55 +12,68 @@
import torch
import xarray as xr
-from fme.core.data_loading.config import DataLoaderConfig, Slice, XarrayDataConfig
-from fme.core.data_loading.data_typing import SigmaCoordinates
-from fme.core.data_loading.getters import (
+import fme
+from fme.ace.data_loading.batch_data import BatchData, PrognosticState
+from fme.ace.data_loading.config import DataLoaderConfig
+from fme.ace.data_loading.getters import (
get_data_loader,
get_forcing_data,
get_inference_data,
)
-from fme.core.data_loading.inference import (
+from fme.ace.data_loading.inference import (
ExplicitIndices,
ForcingDataLoaderConfig,
InferenceDataLoaderConfig,
+ InferenceDataset,
InferenceInitialConditionIndices,
TimestampList,
)
-from fme.core.data_loading.requirements import DataRequirements
-from fme.core.data_loading.utils import BatchData, get_times
+from fme.ace.data_loading.perturbation import PerturbationSelector, SSTPerturbation
+from fme.ace.requirements import PrognosticStateDataRequirements
+from fme.core.coordinates import HybridSigmaPressureCoordinate
+from fme.core.dataset.config import XarrayDataConfig
+from fme.core.dataset.requirements import DataRequirements
+from fme.core.typing_ import Slice
-def _get_coords(dim_sizes, calendar):
+def _get_coords(dim_sizes, calendar, timestep_size=1):
coords = {}
for dim_name, size in dim_sizes.items():
if dim_name == "time":
dtype = np.int64
+ step = timestep_size
+ size = size * step
attrs = {"calendar": calendar, "units": "days since 1970-01-01"}
else:
dtype = np.float32
+ step = 1
attrs = {}
- coord_value = np.arange(size, dtype=dtype)
+ coord_value = np.arange(0, size, step, dtype=dtype)
coord = xr.DataArray(coord_value, dims=(dim_name,), attrs=attrs)
coords[dim_name] = coord
return coords
-def _save_netcdf(filename, dim_sizes, variable_names, calendar):
+def _save_netcdf(filename, dim_sizes, variable_names, calendar, timestep_size=1):
data_vars = {}
for name in variable_names:
- data = np.random.randn(*list(dim_sizes.values()))
+ if name == "constant_mask":
+ data = np.ones(list(dim_sizes.values()))
+ else:
+ data = np.random.randn(*list(dim_sizes.values()))
if len(dim_sizes) > 0:
data = data.astype(np.float32) # type: ignore
data_vars[name] = xr.DataArray(
data, dims=list(dim_sizes), attrs={"units": "m", "long_name": name}
)
- coords = _get_coords(dim_sizes, calendar)
+ coords = _get_coords(dim_sizes, calendar, timestep_size)
for i in range(7):
data_vars[f"ak_{i}"] = float(i)
data_vars[f"bk_{i}"] = float(i + 1)
ds = xr.Dataset(data_vars=data_vars, coords=coords)
ds.to_netcdf(filename, unlimited_dims=["time"], format="NETCDF4_CLASSIC")
+ return ds
def _create_dataset_on_disk(
@@ -68,6 +81,7 @@ def _create_dataset_on_disk(
calendar: str = "proleptic_gregorian",
data_dim_sizes=None,
n_times: int = 3,
+ timestep_size: int = 1,
) -> pathlib.Path:
if data_dim_sizes is None:
data_dim_sizes = {"time": n_times, "grid_yt": 16, "grid_xt": 32}
@@ -76,10 +90,14 @@ def _create_dataset_on_disk(
in_variable_names = ["foo", "bar", "baz"]
out_variable_names = ["foo", "bar"]
mask_name = "mask"
- all_variable_names = list(set(in_variable_names + out_variable_names)) + [mask_name]
+ constant_mask_name = "constant_mask"
+ all_variable_names = list(set(in_variable_names + out_variable_names)) + [
+ mask_name,
+ constant_mask_name,
+ ]
data_path = data_dir / "data.nc"
- _save_netcdf(data_path, data_dim_sizes, all_variable_names, calendar)
+ _save_netcdf(data_path, data_dim_sizes, all_variable_names, calendar, timestep_size)
return data_path
@@ -108,8 +126,8 @@ def test_ensemble_loader(tmp_path, num_ensemble_members=3):
samples_per_member = n_timesteps - window_timesteps + 1
data = get_data_loader(config, True, requirements)
- assert len(data.loader) == samples_per_member * num_ensemble_members
- assert isinstance(data.sigma_coordinates, SigmaCoordinates)
+ assert data.n_batches == samples_per_member * num_ensemble_members
+ assert isinstance(data.vertical_coordinate, HybridSigmaPressureCoordinate)
def test_ensemble_loader_n_samples(tmp_path, num_ensemble_members=3, n_samples=1):
@@ -139,12 +157,12 @@ def test_ensemble_loader_n_samples(tmp_path, num_ensemble_members=3, n_samples=1
requirements = DataRequirements(["foo"], window_timesteps)
data = get_data_loader(config, True, requirements)
- assert len(data.loader) == n_samples * num_ensemble_members
- assert isinstance(data.sigma_coordinates, SigmaCoordinates)
+ assert data.n_batches == n_samples * num_ensemble_members
+ assert isinstance(data.vertical_coordinate, HybridSigmaPressureCoordinate)
def test_xarray_loader(tmp_path):
- """Checks that sigma coordinates are present."""
+ """Checks that vertical coordinates are present."""
_create_dataset_on_disk(tmp_path)
config = DataLoaderConfig(
[XarrayDataConfig(data_path=tmp_path, n_repeats=1)],
@@ -154,13 +172,14 @@ def test_xarray_loader(tmp_path):
window_timesteps = 2 # 1 initial condition and 1 step forward
requirements = DataRequirements(["foo"], window_timesteps)
data = get_data_loader(config, True, requirements) # type: ignore
- assert isinstance(data.sigma_coordinates, SigmaCoordinates)
+ assert isinstance(data.vertical_coordinate, HybridSigmaPressureCoordinate)
+ assert data.vertical_coordinate.ak.device == fme.get_device()
def test_xarray_loader_hpx(tmp_path):
- """Checks that sigma coordinates are present."""
+ """Checks that vertical coordinates are present."""
n_times = 3
- data_dim_sizes = {"time": n_times, "face": 12, "width": 64, "height": 64}
+ data_dim_sizes = {"time": n_times, "face": 12, "width": 16, "height": 16}
_create_dataset_on_disk(tmp_path, data_dim_sizes=data_dim_sizes, n_times=n_times)
config = DataLoaderConfig(
[
@@ -177,9 +196,10 @@ def test_xarray_loader_hpx(tmp_path):
for batch in data.loader:
assert batch is not None
# expect healpix shape
- assert batch.data["foo"].shape == (1, window_timesteps, 12, 64, 64)
+ assert batch.data["foo"].shape == (1, window_timesteps, 12, 16, 16)
break
- assert isinstance(data.sigma_coordinates, SigmaCoordinates)
+ assert isinstance(data.vertical_coordinate, HybridSigmaPressureCoordinate)
+ assert data.vertical_coordinate.ak.device == fme.get_device()
def test_loader_n_repeats_but_not_infer_timestep_error(tmp_path):
@@ -206,25 +226,42 @@ def test_inference_data_loader(tmp_path):
),
)
n_forward_steps_in_memory = 3
- requirements = DataRequirements(["foo"], n_timesteps=7)
- data_loader = get_inference_data(
+ window_requirements = DataRequirements(
+ names=["foo", "bar"],
+ n_timesteps=n_forward_steps_in_memory + 1,
+ )
+ initial_condition_requirements = PrognosticStateDataRequirements(
+ names=["foo"],
+ n_timesteps=1,
+ )
+ data = get_inference_data(
config,
- forward_steps_in_memory=n_forward_steps_in_memory,
- requirements=requirements,
- ).loader
+ total_forward_steps=6,
+ window_requirements=window_requirements,
+ initial_condition=initial_condition_requirements,
+ )
+ data_loader = data.loader
batch_data = next(iter(data_loader))
assert isinstance(batch_data, BatchData)
- assert isinstance(batch_data.data["foo"], torch.Tensor)
- assert batch_data.data["foo"].shape[0] == batch_size
- assert batch_data.data["foo"].shape[1] == n_forward_steps_in_memory + 1
- assert batch_data.data["foo"].shape[2] == 16
- assert batch_data.data["foo"].shape[3] == 32
- assert isinstance(batch_data.times, xr.DataArray)
- assert list(batch_data.times.dims) == ["sample", "time"]
- assert batch_data.times.sizes["sample"] == batch_size
- assert batch_data.times.sizes["time"] == n_forward_steps_in_memory + 1
- assert batch_data.times.dt.calendar == "proleptic_gregorian"
- assert len(data_loader) == 2
+ for name in ["foo", "bar"]:
+ assert isinstance(batch_data.data[name], torch.Tensor)
+ assert batch_data.data[name].shape == (
+ batch_size,
+ n_forward_steps_in_memory + 1,
+ 16,
+ 32,
+ )
+ assert isinstance(batch_data.time, xr.DataArray)
+ assert list(batch_data.time.dims) == ["sample", "time"]
+ assert batch_data.time.sizes["sample"] == batch_size
+ assert batch_data.time.sizes["time"] == n_forward_steps_in_memory + 1
+ assert batch_data.time.dt.calendar == "proleptic_gregorian"
+ assert data._n_batches == 2
+ assert data.vertical_coordinate.ak.device == fme.get_device()
+ initial_condition = data.initial_condition.as_batch_data()
+ assert isinstance(initial_condition, BatchData)
+ assert "bar" not in initial_condition.data
+ assert initial_condition.data["foo"].shape == (batch_size, 1, 16, 32)
@pytest.fixture(params=["julian", "proleptic_gregorian", "noleap"])
@@ -255,26 +292,11 @@ def test_data_loader_outputs(tmp_path, calendar):
assert isinstance(batch_data, BatchData)
assert isinstance(batch_data.data["foo"], torch.Tensor)
assert batch_data.data["foo"].shape[0] == n_samples
- assert isinstance(batch_data.times, xr.DataArray)
- assert list(batch_data.times.dims) == ["sample", "time"]
- assert batch_data.times.sizes["sample"] == n_samples
- assert batch_data.times.sizes["time"] == window_timesteps
- assert batch_data.times.dt.calendar == calendar
-
-
-def test_get_times_non_cftime():
- """
- Check that `get_times` raises an error when the time coordinate is not
- cftime.datetime
- """
- n_times = 5
- times = [datetime.datetime(2020, 1, 1, i, 0, 0) for i in range(n_times)]
- ds = xr.Dataset(
- {"foo": xr.DataArray(np.arange(n_times), dims=("time",))},
- coords={"time": times},
- )
- with pytest.raises(AssertionError):
- get_times(ds, 0, 1)
+ assert isinstance(batch_data.time, xr.DataArray)
+ assert list(batch_data.time.dims) == ["sample", "time"]
+ assert batch_data.time.sizes["sample"] == n_samples
+ assert batch_data.time.sizes["time"] == window_timesteps
+ assert batch_data.time.dt.calendar == calendar
@pytest.mark.parametrize(
@@ -320,20 +342,29 @@ def test_inference_data_loader_validate_n_forward_steps(
start_indices=start_indices,
)
n_forward_steps_in_memory = num_forward_steps
- requirements = DataRequirements(["foo"], n_timesteps=num_forward_steps + 1)
+ window_requirements = DataRequirements(
+ names=["foo", "bar"],
+ n_timesteps=n_forward_steps_in_memory + 1,
+ )
+ initial_condition_requirements = PrognosticStateDataRequirements(
+ names=["foo"],
+ n_timesteps=1,
+ )
if raises_error:
with pytest.raises(ValueError):
get_inference_data(
config,
- forward_steps_in_memory=n_forward_steps_in_memory,
- requirements=requirements,
+ total_forward_steps=num_forward_steps,
+ window_requirements=window_requirements,
+ initial_condition=initial_condition_requirements,
)
else:
get_inference_data(
config,
- forward_steps_in_memory=n_forward_steps_in_memory,
- requirements=requirements,
+ total_forward_steps=num_forward_steps,
+ window_requirements=window_requirements,
+ initial_condition=initial_condition_requirements,
)
@@ -368,27 +399,42 @@ def test_get_forcing_data(tmp_path, n_initial_conditions):
forward_steps_in_memory = 2
_create_dataset_on_disk(tmp_path, calendar=calendar, n_times=10)
config = ForcingDataLoaderConfig(XarrayDataConfig(data_path=tmp_path))
- requirements = DataRequirements(["foo"], total_forward_steps + 1)
+ window_requirements = DataRequirements(
+ names=["foo"],
+ n_timesteps=forward_steps_in_memory + 1,
+ )
time_values = [
- cftime.datetime(1970, 1, 1 + 2 * n, calendar=calendar)
+ [cftime.datetime(1970, 1, 1 + 2 * n, calendar=calendar)]
for n in range(n_initial_conditions)
]
- initial_times = xr.DataArray(time_values, dims=["sample"])
- data = get_forcing_data(
- config, forward_steps_in_memory, requirements, initial_times
+ initial_condition = BatchData.new_on_cpu(
+ data={"foo": torch.randn(n_initial_conditions, 1, 1, 1)},
+ time=xr.DataArray(time_values, dims=["sample", "time"]),
)
- assert len(data.loader.dataset) == math.ceil(
- total_forward_steps / forward_steps_in_memory
+ data = get_forcing_data(
+ config,
+ total_forward_steps,
+ window_requirements=window_requirements,
+ initial_condition=PrognosticState(initial_condition),
)
+ assert data._n_samples == math.ceil(total_forward_steps / forward_steps_in_memory)
batch_data = next(iter(data.loader))
assert isinstance(batch_data, BatchData)
assert isinstance(batch_data.data["foo"], torch.Tensor)
assert set(batch_data.data.keys()) == {"foo"}
assert batch_data.data["foo"].shape[0] == len(time_values)
assert batch_data.data["foo"].shape[1] == forward_steps_in_memory + 1
- assert list(batch_data.times.dims) == ["sample", "time"]
- xr.testing.assert_allclose(batch_data.times.isel(time=0), initial_times)
- assert batch_data.times.dt.calendar == calendar
+ assert list(batch_data.time.dims) == ["sample", "time"]
+ xr.testing.assert_allclose(batch_data.time[:, 0], initial_condition.time[:, 0])
+ assert batch_data.time.dt.calendar == calendar
+ xr.testing.assert_equal(
+ data.initial_condition.as_batch_data().time,
+ initial_condition.time,
+ )
+ np.testing.assert_allclose(
+ data.initial_condition.as_batch_data().data["foo"].cpu().numpy(),
+ initial_condition.data["foo"].cpu().numpy(),
+ )
def test_inference_loader_raises_if_subset():
@@ -446,3 +492,77 @@ def test_TimestampList_as_indices(timestamps, expected_indices):
np.testing.assert_equal(
timestamp_list.as_indices(time_index), np.array(expected_indices)
)
+
+
+def test_inference_data_with_perturbations(tmp_path):
+ _create_dataset_on_disk(tmp_path, n_times=14)
+ batch_size = 1
+ step = 7
+ config = InferenceDataLoaderConfig(
+ XarrayDataConfig(
+ data_path=tmp_path,
+ n_repeats=1,
+ ),
+ start_indices=InferenceInitialConditionIndices(
+ first=0, n_initial_conditions=batch_size, interval=step
+ ),
+ perturbations=SSTPerturbation(
+ sst=[PerturbationSelector(type="constant", config={"amplitude": 2.0})]
+ ),
+ )
+ n_forward_steps_in_memory = 3
+ original_foo = xr.open_dataset(os.path.join(tmp_path, "data.nc"))["foo"].values[
+ 0 : n_forward_steps_in_memory + 1, :, :
+ ]
+ window_requirements = DataRequirements(
+ names=["foo", "constant_mask"],
+ n_timesteps=n_forward_steps_in_memory + 1,
+ )
+ initial_condition_requirements = PrognosticStateDataRequirements(
+ names=["foo"],
+ n_timesteps=1,
+ )
+ data = get_inference_data(
+ config,
+ total_forward_steps=6,
+ window_requirements=window_requirements,
+ initial_condition=initial_condition_requirements,
+ surface_temperature_name="foo",
+ ocean_fraction_name="constant_mask",
+ )
+ batch_data = next(iter(data.loader))
+ np.testing.assert_allclose(
+ original_foo + 2.0,
+ batch_data.data["foo"].cpu().numpy()[0, :, :, :],
+ )
+ np.testing.assert_allclose(
+ original_foo[:1, :, :] + 2.0,
+ data.initial_condition.as_batch_data().data["foo"].cpu().numpy()[0, :, :, :],
+ )
+
+
+def test_inference_persistence_names(tmp_path):
+ _create_dataset_on_disk(tmp_path, n_times=14)
+
+ config = InferenceDataLoaderConfig(
+ XarrayDataConfig(data_path=tmp_path),
+ start_indices=ExplicitIndices([0, 3]),
+ persistence_names=["foo"],
+ )
+ window_requirements = DataRequirements(
+ names=["foo", "bar"],
+ n_timesteps=3,
+ )
+ dataset = InferenceDataset(
+ config,
+ 9,
+ requirements=window_requirements,
+ )
+ first_item = dataset[0].data
+ second_item = dataset[1].data
+ # ensure first and second time steps are the same
+ torch.testing.assert_close(first_item["foo"][:, 0], first_item["foo"][:, 1])
+    # ensure the entire first and second returned items are identical
+ torch.testing.assert_close(first_item["foo"], second_item["foo"])
+ # ensure this is not the case for another variable
+ assert not torch.all(first_item["bar"] == second_item["bar"])
diff --git a/fme/fme/ace/data_loading/test_data_loading_config.py b/fme/fme/ace/data_loading/test_data_loading_config.py
new file mode 100644
index 0000000..cfbd067
--- /dev/null
+++ b/fme/fme/ace/data_loading/test_data_loading_config.py
@@ -0,0 +1,69 @@
+from datetime import timedelta
+
+import numpy as np
+import pandas as pd
+import pytest
+
+from fme.core.dataset.config import RepeatedInterval
+
+
+def test_repeated_interval_int():
+ interval = RepeatedInterval(interval_length=3, block_length=6, start=0)
+ mask = interval.get_boolean_mask(length=18)
+ expected_mask = np.array([True, True, True, False, False, False] * 3)
+ np.testing.assert_array_equal(mask, expected_mask)
+
+
+def test_repeated_interval_str():
+ interval = RepeatedInterval(interval_length="1d", block_length="7d", start="2d")
+ mask = interval.get_boolean_mask(length=21, timestep=timedelta(days=1))
+ expected_mask = np.array([False, False, True, False, False, False, False] * 3)
+ np.testing.assert_array_equal(mask, expected_mask)
+
+
+def test_repeated_interval_mixed_types():
+ with pytest.raises(ValueError):
+ RepeatedInterval(interval_length=3, block_length="6d", start=0)
+
+
+@pytest.mark.parametrize("interval, block, start", [(4, 6, 3), ("2d", "3d", "2d")])
+def test_repeated_interval_invalid_interval_start(interval, block, start):
+ """start + interval exceeds length of block"""
+ interval = RepeatedInterval(
+ interval_length=interval, block_length=block, start=start
+ )
+ with pytest.raises(ValueError):
+ interval.get_boolean_mask(length=18, timestep=timedelta(days=1))
+
+
+def test_repeated_interval_zero_length():
+ interval = RepeatedInterval(interval_length=0, block_length=6, start=0)
+ mask = interval.get_boolean_mask(length=18)
+ expected_mask = np.array([False] * 18)
+ np.testing.assert_array_equal(mask, expected_mask)
+
+
+def test_repeated_interval_partial_block():
+ interval = RepeatedInterval(interval_length=3, block_length=6, start=0)
+ mask = interval.get_boolean_mask(length=20)
+ expected_mask = np.array([True, True, True, False, False, False] * 3 + [True, True])
+ np.testing.assert_array_equal(mask, expected_mask)
+
+
+def test_repeated_interval_no_timestep_fails_for_timedelta_lengths():
+ interval = RepeatedInterval(interval_length="1d", block_length="7d", start="0d")
+ with pytest.raises(ValueError):
+ interval.get_boolean_mask(length=20)
+
+
+@pytest.mark.parametrize("timestep", ["2h", "150m", "5h", "12h"])
+def test_invalid_timesteps(timestep):
+ """
+ Test that a timestep which does not evenly divide one or more of the
+ interval_length, start, or block_length arguments raises a ValueError.
+ """
+ timestep = pd.to_timedelta(timestep)
+ with pytest.raises(ValueError):
+ RepeatedInterval(
+ interval_length="5h", start="4h", block_length="10h"
+ ).get_boolean_mask(length=18, timestep=timestep)
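
# Hedged worked example (not part of the diff): the mask layout exercised by the
# tests above. Each block of length 6 starts with interval_length=3 True values,
# and a trailing partial block is truncated to the requested length.
from fme.core.dataset.config import RepeatedInterval

mask = RepeatedInterval(interval_length=3, block_length=6, start=0).get_boolean_mask(
    length=20
)
assert mask.tolist() == [True, True, True, False, False, False] * 3 + [True, True]
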
diff --git a/fme/fme/core/data_loading/test_metadata.py b/fme/fme/ace/data_loading/test_metadata.py
similarity index 57%
rename from fme/fme/core/data_loading/test_metadata.py
rename to fme/fme/ace/data_loading/test_metadata.py
index 702e7a3..19f4e94 100644
--- a/fme/fme/core/data_loading/test_metadata.py
+++ b/fme/fme/ace/data_loading/test_metadata.py
@@ -6,32 +6,11 @@
import pytest
import xarray as xr
-from fme.core.data_loading.config import DataLoaderConfig, XarrayDataConfig
-from fme.core.data_loading.data_typing import VariableMetadata
-from fme.core.data_loading.getters import get_data_loader
-from fme.core.data_loading.requirements import DataRequirements
-
-METADATA = [
- pytest.param(
- {"bar": None},
- id="one_var_no_metadata",
- ),
- pytest.param(
- {"bar": VariableMetadata("km", "bar_long_name")},
- id="one_var_metadata",
- ),
- pytest.param(
- {"foo": VariableMetadata("m", "foo_long_name"), "bar": None},
- id="two_vars_one_metadata",
- ),
- pytest.param(
- {
- "foo": VariableMetadata("m", "foo_long_name"),
- "bar": VariableMetadata("km", "bar_long_name"),
- },
- id="two_vars_two_metadata",
- ),
-]
+from fme.ace.data_loading.config import DataLoaderConfig
+from fme.ace.data_loading.getters import get_data_loader
+from fme.core.dataset.config import XarrayDataConfig
+from fme.core.dataset.data_typing import VariableMetadata
+from fme.core.dataset.requirements import DataRequirements
def _coord_value(name, size):
@@ -47,18 +26,18 @@ def _coord_value(name, size):
def _save_netcdf(
filename,
- metadata: Mapping[str, Optional[VariableMetadata]],
+ variable_metadata: Mapping[str, Optional[VariableMetadata]],
num_members=1,
dim_sizes=None,
):
if dim_sizes is None:
dim_sizes = {"time": 3, "grid_yt": 16, "grid_xt": 32}
data_vars = {}
- for name in metadata:
+ for name in variable_metadata:
data = np.random.randn(*list(dim_sizes.values()))
if len(dim_sizes) > 0:
data = data.astype(np.float32)
- item_metadata = metadata[name]
+ item_metadata = variable_metadata[name]
if item_metadata is None:
attrs = {}
else:
@@ -84,24 +63,49 @@ def _save_netcdf(
@pytest.mark.parametrize("n_ensemble_members", [1, 2])
-@pytest.mark.parametrize("metadata", METADATA)
-def test_metadata(tmp_path, metadata, n_ensemble_members):
+@pytest.mark.parametrize(
+ "variable_metadata",
+ [
+ pytest.param(
+ {"bar": None},
+ id="one_var_no_metadata",
+ ),
+ pytest.param(
+ {"bar": VariableMetadata("km", "bar_long_name")},
+ id="one_var_metadata",
+ ),
+ pytest.param(
+ {"foo": VariableMetadata("m", "foo_long_name"), "bar": None},
+ id="two_vars_one_metadata",
+ ),
+ pytest.param(
+ {
+ "foo": VariableMetadata("m", "foo_long_name"),
+ "bar": VariableMetadata("km", "bar_long_name"),
+ },
+ id="two_vars_two_metadata",
+ ),
+ ],
+)
+def test_metadata(tmp_path, variable_metadata, n_ensemble_members):
paths = []
for i in range(n_ensemble_members):
path = tmp_path / f"ic{i}"
path.mkdir(exist_ok=True)
paths.append(path)
- _save_netcdf(path / "data.nc", metadata)
+ _save_netcdf(path / "data.nc", variable_metadata)
config = DataLoaderConfig(
[XarrayDataConfig(data_path=str(path)) for path in paths],
batch_size=1,
num_data_workers=0,
)
- var_names = list(metadata.keys())
+ var_names = list(variable_metadata.keys())
requirements = DataRequirements(names=var_names, n_timesteps=2)
data = get_data_loader(config=config, train=True, requirements=requirements)
target_metadata = {
- name: metadata[name] for name in metadata if metadata[name] is not None
+ name: variable_metadata[name]
+ for name in variable_metadata
+ if variable_metadata[name] is not None
}
- assert data.metadata == target_metadata # type: ignore
+ assert data.variable_metadata == target_metadata # type: ignore
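
# Hedged usage sketch (not part of the diff): the loader attribute checked in the
# test above is now `variable_metadata`. The dataset path and variable names below
# are illustrative; constructor and function signatures are taken from this test.
from fme.ace.data_loading.config import DataLoaderConfig
from fme.ace.data_loading.getters import get_data_loader
from fme.core.dataset.config import XarrayDataConfig
from fme.core.dataset.requirements import DataRequirements

config = DataLoaderConfig(
    [XarrayDataConfig(data_path="/path/to/ic0")],  # hypothetical dataset directory
    batch_size=1,
    num_data_workers=0,
)
requirements = DataRequirements(names=["foo", "bar"], n_timesteps=2)
data = get_data_loader(config=config, train=True, requirements=requirements)
print(data.variable_metadata)  # units/long_name per variable, read from the netCDFs
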
diff --git a/fme/fme/ace/data_loading/test_perturbation.py b/fme/fme/ace/data_loading/test_perturbation.py
new file mode 100644
index 0000000..296d14b
--- /dev/null
+++ b/fme/fme/ace/data_loading/test_perturbation.py
@@ -0,0 +1,54 @@
+import torch
+
+import fme
+from fme.ace.data_loading.perturbation import (
+ ConstantConfig,
+ GreensFunctionConfig,
+ PerturbationSelector,
+)
+
+
+def test_constant_perturbation_config():
+ selector = PerturbationSelector(
+ type="constant",
+ config={"amplitude": 1.0},
+ )
+ perturbation = selector.build()
+ assert isinstance(perturbation, ConstantConfig)
+ assert perturbation.amplitude == 1.0
+ nx, ny = 5, 5
+ lat = torch.arange(nx, device=fme.get_device())
+ lon = torch.arange(ny, device=fme.get_device())
+ lats, lons = torch.meshgrid(lat, lon, indexing="ij")
+ ocean_fraction = torch.ones(nx, ny, device=fme.get_device())
+ data = torch.ones(nx, ny, device=fme.get_device())
+ expected = 2.0 * torch.ones(nx, ny, device=fme.get_device())
+ perturbation.apply_perturbation(data, lats, lons, ocean_fraction)
+ torch.testing.assert_close(data, expected)
+
+
+def test_green_function_perturbation_config():
+ selector = PerturbationSelector(
+ type="greens_function",
+ config={
+ "amplitude": 1.0,
+ "lat_center": 0.0,
+ "lon_center": 0.0,
+ "lat_width": 10.0,
+ "lon_width": 10.0,
+ },
+ )
+ perturbation = selector.build()
+ assert isinstance(perturbation, GreensFunctionConfig)
+ assert perturbation.amplitude == 1.0
+ assert perturbation.lat_center == 0.0
+ assert perturbation.lon_center == 0.0
+ assert perturbation.lat_width == 10.0
+ assert perturbation.lon_width == 10.0
+ nx, ny = 5, 5
+ lat = torch.arange(nx, device=fme.get_device())
+ lon = torch.arange(ny, device=fme.get_device())
+ lats, lons = torch.meshgrid(lat, lon, indexing="ij")
+ ocean_fraction = torch.ones(nx, ny, device=fme.get_device())
+ data = torch.ones(nx, ny, device=fme.get_device())
+ perturbation.apply_perturbation(data, lats, lons, ocean_fraction)
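
# Hedged sketch (not part of the diff): the constant perturbation built above
# modifies the data tensor in place; with ocean_fraction all ones (the only case
# the test exercises) the whole field is shifted by `amplitude`.
import torch

import fme
from fme.ace.data_loading.perturbation import PerturbationSelector

perturbation = PerturbationSelector(type="constant", config={"amplitude": 2.0}).build()
device = fme.get_device()
field = torch.zeros(5, 5, device=device)
lat, lon = torch.arange(5, device=device), torch.arange(5, device=device)
lats, lons = torch.meshgrid(lat, lon, indexing="ij")
perturbation.apply_perturbation(field, lats, lons, torch.ones(5, 5, device=device))
# field now equals 2.0 everywhere (updated in place; no return value is used)
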
diff --git a/fme/fme/ace/inference/__init__.py b/fme/fme/ace/inference/__init__.py
index 20dba26..e69de29 100644
--- a/fme/fme/ace/inference/__init__.py
+++ b/fme/fme/ace/inference/__init__.py
@@ -1 +0,0 @@
-from .loop import run_inference_evaluator
diff --git a/fme/fme/ace/inference/__main__.py b/fme/fme/ace/inference/__main__.py
index 8825044..e9ccddd 100644
--- a/fme/fme/ace/inference/__main__.py
+++ b/fme/fme/ace/inference/__main__.py
@@ -4,5 +4,13 @@
parser = argparse.ArgumentParser()
parser.add_argument("yaml_config", type=str)
+parser.add_argument(
+ "--segments",
+ type=int,
+ default=None,
+ help="If provided, number of times to repeat the inference in time, saving each "
+ "segment in a separate folder labeled as 'segment_0000', 'segment_0001' etc. "
+ "WARNING: this feature is experimental and its API is subject to change.",
+)
args = parser.parse_args()
-main(yaml_config=args.yaml_config)
+main(yaml_config=args.yaml_config, segments=args.segments)
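
# Hedged CLI example (not part of the diff): since this is the package's __main__
# module, the experimental flag described above might be invoked as
#
#     python -m fme.ace.inference inference-config.yaml --segments 4
#
# which, per the help text, repeats the inference four times in sequence and writes
# each continuation under segment_0000 ... segment_0003 (config filename illustrative).
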
diff --git a/fme/fme/ace/inference/data_writer/__init__.py b/fme/fme/ace/inference/data_writer/__init__.py
index 1749072..641c6f8 100644
--- a/fme/fme/ace/inference/data_writer/__init__.py
+++ b/fme/fme/ace/inference/data_writer/__init__.py
@@ -1 +1 @@
-from .main import DataWriter, DataWriterConfig, NullDataWriter, PairedDataWriter
+from .main import DataWriter, DataWriterConfig, PairedDataWriter
diff --git a/fme/fme/ace/inference/data_writer/histograms.py b/fme/fme/ace/inference/data_writer/histograms.py
index 7d2cc5b..c87057c 100644
--- a/fme/fme/ace/inference/data_writer/histograms.py
+++ b/fme/fme/ace/inference/data_writer/histograms.py
@@ -5,7 +5,7 @@
import torch
import xarray as xr
-from fme.core.data_loading.data_typing import VariableMetadata
+from fme.core.dataset.data_typing import VariableMetadata
from fme.core.histogram import DynamicHistogram
@@ -73,21 +73,21 @@ def __init__(
self,
path: str,
n_timesteps: int,
- metadata: Mapping[str, VariableMetadata],
+ variable_metadata: Mapping[str, VariableMetadata],
save_names: Optional[Sequence[str]],
):
self._target_writer = HistogramDataWriter(
path=path,
n_timesteps=n_timesteps,
filename="histograms_target.nc",
- metadata=metadata,
+ variable_metadata=variable_metadata,
save_names=save_names,
)
self._prediction_writer = HistogramDataWriter(
path=path,
n_timesteps=n_timesteps,
filename="histograms_prediction.nc",
- metadata=metadata,
+ variable_metadata=variable_metadata,
save_names=save_names,
)
@@ -96,17 +96,17 @@ def append_batch(
target: Dict[str, torch.Tensor],
prediction: Dict[str, torch.Tensor],
start_timestep: int,
- batch_times: xr.DataArray,
+ batch_time: xr.DataArray,
):
self._target_writer.append_batch(
data=target,
start_timestep=start_timestep,
- batch_times=batch_times,
+ batch_time=batch_time,
)
self._prediction_writer.append_batch(
data=prediction,
start_timestep=start_timestep,
- batch_times=batch_times,
+ batch_time=batch_time,
)
def flush(self):
@@ -124,36 +124,37 @@ def __init__(
path: str,
n_timesteps: int,
filename: str,
- metadata: Mapping[str, VariableMetadata],
+ variable_metadata: Mapping[str, VariableMetadata],
save_names: Optional[Sequence[str]],
):
"""
Args:
- path: Path to write netCDF file(s).
+ path: The directory within which to write the file.
n_timesteps: Number of timesteps to write to the file.
- metadata: Metadata for each variable to be written to the file.
+ filename: Name of the file to write.
+ variable_metadata: Metadata for each variable to be written to the file.
+ save_names: Names of variables to save. If None, all variables are saved.
"""
self.path = path
self._metrics_filename = str(Path(path) / filename)
- self.metadata = metadata
+ self.variable_metadata = variable_metadata
self._histogram = _HistogramAggregator(n_times=n_timesteps, names=save_names)
def append_batch(
self,
data: Dict[str, torch.Tensor],
start_timestep: int,
- batch_times: xr.DataArray,
+ batch_time: xr.DataArray,
):
"""
Append a batch of data to the file.
Args:
- target: Target data.
- prediction: Prediction data.
+ data: The data to write.
start_timestep: Timestep at which to start writing.
- batch_times: Time coordinates for each sample in the batch.
+ batch_time: Time coordinate for each sample in the batch.
"""
- del batch_times
+ del batch_time
self._histogram.record_batch(
data=data,
i_time_start=start_timestep,
@@ -164,11 +165,11 @@ def flush(self):
Flush the data to disk.
"""
metric_dataset = self._histogram.get_dataset()
- for name in self.metadata:
+ for name in self.variable_metadata:
try:
- metric_dataset[f"{name}_bin_edges"].attrs["units"] = self.metadata[
- name
- ].units
+ metric_dataset[f"{name}_bin_edges"].attrs["units"] = (
+ self.variable_metadata[name].units
+ )
except KeyError:
logging.info(
f"{name} in metadata but not in data written to "
@@ -176,9 +177,9 @@ def flush(self):
)
for name in metric_dataset.data_vars:
if not name.endswith("_bin_edges"):
- metric_dataset[f"{name}_bin_edges"].attrs[
- "long_name"
- ] = f"{name} bin edges"
+ metric_dataset[f"{name}_bin_edges"].attrs["long_name"] = (
+ f"{name} bin edges"
+ )
metric_dataset[name].attrs["units"] = "count"
metric_dataset[name].attrs["long_name"] = f"{name} histogram"
metric_dataset.to_netcdf(self._metrics_filename)
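
# Hedged usage sketch (not part of the diff): constructing the writer with the
# renamed `variable_metadata` keyword. The path, counts, and variable names are
# illustrative; the constructor and method signatures are taken from this file.
from fme.ace.inference.data_writer.histograms import HistogramDataWriter
from fme.core.dataset.data_typing import VariableMetadata

writer = HistogramDataWriter(
    path="/tmp/inference_output",  # hypothetical output directory
    n_timesteps=8,
    filename="histograms_prediction.nc",
    variable_metadata={"temp": VariableMetadata("K", "air temperature")},
    save_names=None,  # None keeps all variables
)
# writer.append_batch(data=tensors, start_timestep=0, batch_time=time_coord)
# writer.flush()  # writes /tmp/inference_output/histograms_prediction.nc
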
diff --git a/fme/fme/ace/inference/data_writer/main.py b/fme/fme/ace/inference/data_writer/main.py
index b70aa1c..df908dd 100644
--- a/fme/fme/ace/inference/data_writer/main.py
+++ b/fme/fme/ace/inference/data_writer/main.py
@@ -2,18 +2,19 @@
import datetime
import warnings
from pathlib import Path
-from typing import Dict, List, Mapping, Optional, Sequence, Union
+from typing import List, Mapping, Optional, Sequence, Union
import numpy as np
import torch
import xarray as xr
-from fme.core.data_loading.data_typing import VariableMetadata
+from fme.ace.data_loading.batch_data import BatchData, PairedData, PrognosticState
+from fme.core.dataset.data_typing import VariableMetadata
+from fme.core.generics.writer import WriterABC
from .histograms import PairedHistogramDataWriter
from .monthly import MonthlyDataWriter, PairedMonthlyDataWriter, months_for_timesteps
from .raw import PairedRawDataWriter, RawDataWriter
-from .restart import PairedRestartWriter, RestartWriter
from .time_coarsen import PairedTimeCoarsen, TimeCoarsen, TimeCoarsenConfig
from .video import PairedVideoDataWriter
@@ -23,10 +24,9 @@
PairedHistogramDataWriter,
PairedTimeCoarsen,
PairedMonthlyDataWriter,
- PairedRestartWriter,
]
-Subwriter = Union[MonthlyDataWriter, RawDataWriter, RestartWriter, TimeCoarsen]
+Subwriter = Union[MonthlyDataWriter, RawDataWriter, TimeCoarsen]
@dataclasses.dataclass
@@ -34,7 +34,7 @@ class DataWriterConfig:
"""
Configuration for inference data writers.
- Attributes:
+ Parameters:
log_extended_video_netcdfs: Whether to enable writing of netCDF files
containing video metrics.
save_prediction_files: Whether to enable writing of netCDF files
@@ -73,25 +73,23 @@ def __post_init__(self):
def build_paired(
self,
experiment_dir: str,
- n_samples: int,
+ n_initial_conditions: int,
n_timesteps: int,
timestep: datetime.timedelta,
- prognostic_names: Sequence[str],
- metadata: Mapping[str, VariableMetadata],
+ variable_metadata: Mapping[str, VariableMetadata],
coords: Mapping[str, np.ndarray],
) -> "PairedDataWriter":
return PairedDataWriter(
path=experiment_dir,
- n_samples=n_samples,
+ n_initial_conditions=n_initial_conditions,
n_timesteps=n_timesteps,
timestep=timestep,
- metadata=metadata,
+ variable_metadata=variable_metadata,
coords=coords,
enable_prediction_netcdfs=self.save_prediction_files,
enable_monthly_netcdfs=self.save_monthly_files,
enable_video_netcdfs=self.log_extended_video_netcdfs,
save_names=self.names,
- prognostic_names=prognostic_names,
enable_histogram_netcdfs=self.save_histogram_files,
time_coarsen=self.time_coarsen,
)
@@ -99,11 +97,10 @@ def build_paired(
def build(
self,
experiment_dir: str,
- n_samples: int,
+ n_initial_conditions: int,
n_timesteps: int,
timestep: datetime.timedelta,
- prognostic_names: Sequence[str],
- metadata: Mapping[str, VariableMetadata],
+ variable_metadata: Mapping[str, VariableMetadata],
coords: Mapping[str, np.ndarray],
) -> "DataWriter":
if self.save_histogram_files:
@@ -118,43 +115,42 @@ def build(
)
return DataWriter(
path=experiment_dir,
- n_samples=n_samples,
+ n_initial_conditions=n_initial_conditions,
n_timesteps=n_timesteps,
- metadata=metadata,
+ variable_metadata=variable_metadata,
coords=coords,
timestep=timestep,
enable_prediction_netcdfs=self.save_prediction_files,
enable_monthly_netcdfs=self.save_monthly_files,
save_names=self.names,
- prognostic_names=prognostic_names,
time_coarsen=self.time_coarsen,
)
-class PairedDataWriter:
+class PairedDataWriter(WriterABC[PrognosticState, PairedData]):
def __init__(
self,
path: str,
- n_samples: int,
+ n_initial_conditions: int,
n_timesteps: int,
- metadata: Mapping[str, VariableMetadata],
+ variable_metadata: Mapping[str, VariableMetadata],
coords: Mapping[str, np.ndarray],
timestep: datetime.timedelta,
enable_prediction_netcdfs: bool,
enable_monthly_netcdfs: bool,
enable_video_netcdfs: bool,
save_names: Optional[Sequence[str]],
- prognostic_names: Sequence[str],
enable_histogram_netcdfs: bool,
time_coarsen: Optional[TimeCoarsenConfig] = None,
):
"""
Args:
path: Path to write netCDF file(s).
- n_samples: Number of samples to write to the file.
+ n_initial_conditions: Number of ICs/ensemble members to write to the file.
n_timesteps: Number of timesteps to write to the file.
- metadata: Metadata for each variable to be written to the file.
+ variable_metadata: Metadata for each variable to be written to the file.
coords: Coordinate data to be written to the file.
+ timestep: Timestep of the model.
enable_prediction_netcdfs: Whether to enable writing of netCDF files
containing the predictions and target values.
enable_monthly_netcdfs: Whether to enable writing of netCDF files
@@ -169,8 +165,7 @@ def __init__(
self._writers: List[PairedSubwriter] = []
self.path = path
self.coords = coords
- self.metadata = metadata
- self.prognostic_names = prognostic_names
+ self.variable_metadata = variable_metadata
if time_coarsen is not None:
n_coarsened_timesteps = time_coarsen.n_coarsened_timesteps(n_timesteps)
@@ -187,9 +182,9 @@ def _time_coarsen_builder(data_writer: PairedSubwriter) -> PairedSubwriter:
_time_coarsen_builder(
PairedRawDataWriter(
path=path,
- n_samples=n_samples,
+ n_initial_conditions=n_initial_conditions,
save_names=save_names,
- metadata=metadata,
+ variable_metadata=variable_metadata,
coords=coords,
)
)
@@ -198,11 +193,11 @@ def _time_coarsen_builder(data_writer: PairedSubwriter) -> PairedSubwriter:
self._writers.append(
PairedMonthlyDataWriter(
path=path,
- n_samples=n_samples,
+ n_samples=n_initial_conditions,
n_timesteps=n_timesteps,
timestep=timestep,
save_names=save_names,
- metadata=metadata,
+ variable_metadata=variable_metadata,
coords=coords,
)
)
@@ -212,7 +207,7 @@ def _time_coarsen_builder(data_writer: PairedSubwriter) -> PairedSubwriter:
PairedVideoDataWriter(
path=path,
n_timesteps=n_coarsened_timesteps,
- metadata=metadata,
+ variable_metadata=variable_metadata,
coords=coords,
)
)
@@ -223,66 +218,46 @@ def _time_coarsen_builder(data_writer: PairedSubwriter) -> PairedSubwriter:
PairedHistogramDataWriter(
path=path,
n_timesteps=n_coarsened_timesteps,
- metadata=metadata,
+ variable_metadata=variable_metadata,
save_names=save_names,
)
)
)
- self._writers.append(
- PairedRestartWriter(
- path=path,
- is_restart_step=lambda i: i == n_timesteps - 1,
- prognostic_names=prognostic_names,
- metadata=metadata,
- coords=coords,
- )
- )
+ self._n_timesteps_seen = 0
- def save_initial_condition(
- self,
- ic_data: Dict[str, torch.Tensor],
- ic_time: xr.DataArray,
- ):
- data_arrays = {}
- for name in self.prognostic_names:
- if name not in ic_data:
- raise KeyError(
- f"Initial condition data missing for prognostic variable {name}."
- )
- data = ic_data[name].cpu().numpy()
- data_arrays[name] = xr.DataArray(data, dims=["sample", "lat", "lon"])
- if name in self.metadata:
- data_arrays[name].attrs = {
- "long_name": self.metadata[name].long_name,
- "units": self.metadata[name].units,
- }
- data_arrays["time"] = ic_time
- ds = xr.Dataset(data_arrays, coords=self.coords)
- ds.to_netcdf(str(Path(self.path) / "initial_condition.nc"))
+ def write(self, data: PrognosticState, filename: str):
+ """Eagerly write data to a single netCDF file.
+
+ Args:
+ data: the data to be written.
+ filename: the filename to use for the netCDF file.
+ """
+ _write(
+ data=data.as_batch_data(),
+ path=self.path,
+ filename=filename,
+ variable_metadata=self.variable_metadata,
+ coords=self.coords,
+ )
def append_batch(
self,
- target: Dict[str, torch.Tensor],
- prediction: Dict[str, torch.Tensor],
- start_timestep: int,
- batch_times: xr.DataArray,
+ batch: PairedData,
):
"""
Append a batch of data to the file.
Args:
- target: Target data.
- prediction: Prediction data.
- start_timestep: Timestep at which to start writing.
- batch_times: Time coordinates for each sample in the batch.
+ batch: Prediction and target data.
"""
for writer in self._writers:
writer.append_batch(
- target=target,
- prediction=prediction,
- start_timestep=start_timestep,
- batch_times=batch_times,
+ target=dict(batch.target),
+ prediction=dict(batch.prediction),
+ start_timestep=self._n_timesteps_seen,
+ batch_time=batch.time,
)
+ self._n_timesteps_seen += batch.time.shape[1]
def flush(self):
"""
@@ -292,27 +267,76 @@ def flush(self):
writer.flush()
-class DataWriter:
+def _write(
+ data: BatchData,
+ path: str,
+ filename: str,
+ variable_metadata: Mapping[str, VariableMetadata],
+ coords: Mapping[str, np.ndarray],
+):
+ """Write provided data to a single netCDF at specified path/filename.
+
+ If the data has only one timestep, the data is squeezed to remove
+ the time dimension.
+
+ Args:
+ data: Batch data to be written.
+ path: Directory to write the netCDF file in.
+ filename: filename to use for netCDF.
+ variable_metadata: Metadata for each variable to be written to the file.
+ coords: Coordinate data to be written to the file.
+ """
+ if data.time.sizes["time"] == 1:
+ time_dim = data.dims.index("time")
+ dims_to_write = data.dims[:time_dim] + data.dims[time_dim + 1 :]
+
+ def maybe_squeeze(x: torch.Tensor) -> torch.Tensor:
+ return x.squeeze(dim=time_dim)
+
+ time_array = data.time.isel(time=0)
+ else:
+ dims_to_write = data.dims
+
+ def maybe_squeeze(x):
+ return x
+
+ time_array = data.time
+
+ data_arrays = {}
+ for name in data.data:
+ array = maybe_squeeze(data.data[name]).cpu().numpy()
+ data_arrays[name] = xr.DataArray(array, dims=dims_to_write)
+ if name in variable_metadata:
+ data_arrays[name].attrs = {
+ "long_name": variable_metadata[name].long_name,
+ "units": variable_metadata[name].units,
+ }
+ data_arrays["time"] = time_array
+ ds = xr.Dataset(data_arrays, coords=coords)
+ ds.to_netcdf(str(Path(path) / filename))
+
+
+class DataWriter(WriterABC[PrognosticState, BatchData]):
def __init__(
self,
path: str,
- n_samples: int,
+ n_initial_conditions: int,
n_timesteps: int,
- metadata: Mapping[str, VariableMetadata],
+ variable_metadata: Mapping[str, VariableMetadata],
coords: Mapping[str, np.ndarray],
timestep: datetime.timedelta,
enable_prediction_netcdfs: bool,
enable_monthly_netcdfs: bool,
save_names: Optional[Sequence[str]],
- prognostic_names: Sequence[str],
time_coarsen: Optional[TimeCoarsenConfig] = None,
):
"""
Args:
path: Directory within which to write netCDF file(s).
- n_samples: Number of samples to write to the file.
+ n_initial_conditions: Number of initial conditions / timeseries
+ to write to the file.
n_timesteps: Number of timesteps to write to the file.
- metadata: Metadata for each variable to be written to the file.
+ variable_metadata: Metadata for each variable to be written to the file.
coords: Coordinate data to be written to the file.
timestep: Timestep of the model.
enable_prediction_netcdfs: Whether to enable writing of netCDF files
@@ -323,6 +347,10 @@ def __init__(
time_coarsen: Configuration for time coarsening of raw outputs.
"""
self._writers: List[Subwriter] = []
+ if "face" in coords:
+ # TODO: handle writing HEALPix data
+ # https://github.com/ai2cm/full-model/issues/1089
+ return
def _time_coarsen_builder(data_writer: Subwriter) -> Subwriter:
if time_coarsen is not None:
@@ -335,9 +363,9 @@ def _time_coarsen_builder(data_writer: Subwriter) -> Subwriter:
RawDataWriter(
path=path,
label="autoregressive_predictions.nc",
- n_samples=n_samples,
+ n_initial_conditions=n_initial_conditions,
save_names=save_names,
- metadata=metadata,
+ variable_metadata=variable_metadata,
coords=coords,
)
)
@@ -348,40 +376,33 @@ def _time_coarsen_builder(data_writer: Subwriter) -> Subwriter:
MonthlyDataWriter(
path=path,
label="predictions",
- n_samples=n_samples,
+ n_samples=n_initial_conditions,
n_months=months_for_timesteps(n_timesteps, timestep),
save_names=save_names,
- metadata=metadata,
+ variable_metadata=variable_metadata,
coords=coords,
)
)
- self._writers.append(
- RestartWriter(
- path=path,
- is_restart_step=lambda i: i == n_timesteps - 1,
- prognostic_names=prognostic_names,
- metadata=metadata,
- coords=coords,
- )
- )
+ self.path = path
+ self.variable_metadata = variable_metadata
+ self.coords = coords
+ self._n_timesteps_seen = 0
- def append_batch(
- self,
- data: Dict[str, torch.Tensor],
- start_timestep: int,
- batch_times: xr.DataArray,
- ):
+ def append_batch(self, batch: BatchData):
"""
- Append a batch of data to the file.
+ Append prediction data to the file.
+
Args:
- data: Data to write.
- start_timestep: Timestep at which to start writing.
- start_sample: Sample at which to start writing.
- batch_times: Time coordinates for each sample in the batch.
+ batch: Data to be written.
"""
for writer in self._writers:
- writer.append_batch(data, start_timestep, batch_times)
+ writer.append_batch(
+ data=dict(batch.data),
+ start_timestep=self._n_timesteps_seen,
+ batch_time=batch.time,
+ )
+ self._n_timesteps_seen += batch.time.shape[1]
def flush(self):
"""
@@ -390,30 +411,11 @@ def flush(self):
for writer in self._writers:
writer.flush()
-
-class NullDataWriter:
- """
- Null pattern for DataWriter, which does nothing.
- """
-
- def __init__(self):
- pass
-
- def append_batch(
- self,
- target: Dict[str, torch.Tensor],
- prediction: Dict[str, torch.Tensor],
- start_timestep: int,
- batch_times: xr.DataArray,
- ):
- pass
-
- def flush(self):
- pass
-
- def save_initial_condition(
- self,
- ic_data: Dict[str, torch.Tensor],
- ic_time: xr.DataArray,
- ):
- pass
+ def write(self, data: PrognosticState, filename: str):
+ _write(
+ data=data.as_batch_data(),
+ path=self.path,
+ filename=filename,
+ variable_metadata=self.variable_metadata,
+ coords=self.coords,
+ )
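
# Hedged sketch (not part of the diff): with RestartWriter removed, both writers
# expose an eager `write` method that saves a PrognosticState to a single netCDF via
# the `_write` helper above. The default-constructed config, directory, sizes, and
# "restart.nc" filename below are illustrative assumptions.
import datetime

import numpy as np

from fme.ace.inference.data_writer import DataWriterConfig
from fme.core.dataset.data_typing import VariableMetadata

writer = DataWriterConfig().build(  # assumes all config fields have defaults
    experiment_dir="/tmp/experiment",
    n_initial_conditions=2,
    n_timesteps=8,
    timestep=datetime.timedelta(hours=6),
    variable_metadata={"temp": VariableMetadata("K", "air temperature")},
    coords={"lat": np.arange(4), "lon": np.arange(5)},
)
# writer.append_batch(batch)  # batch: BatchData; the writer tracks the time
#                             # offset internally via _n_timesteps_seen
# writer.write(final_state, "restart.nc")  # final_state: PrognosticState
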
diff --git a/fme/fme/ace/inference/data_writer/monthly.py b/fme/fme/ace/inference/data_writer/monthly.py
index 6254f38..3b8f609 100644
--- a/fme/fme/ace/inference/data_writer/monthly.py
+++ b/fme/fme/ace/inference/data_writer/monthly.py
@@ -9,8 +9,12 @@
import xarray as xr
from netCDF4 import Dataset
-from fme.ace.inference.data_writer.utils import get_all_names
-from fme.core.data_loading.data_typing import VariableMetadata
+from fme.ace.inference.data_writer.utils import (
+ DIM_INFO_HEALPIX,
+ DIM_INFO_LATLON,
+ get_all_names,
+)
+from fme.core.dataset.data_typing import VariableMetadata
LEAD_TIME_DIM = "time"
LEAD_TIME_UNITS = "months"
@@ -41,7 +45,7 @@ def __init__(
n_timesteps: int,
timestep: datetime.timedelta,
save_names: Optional[Sequence[str]],
- metadata: Mapping[str, VariableMetadata],
+ variable_metadata: Mapping[str, VariableMetadata],
coords: Mapping[str, np.ndarray],
):
n_months = months_for_timesteps(n_timesteps, timestep)
@@ -51,7 +55,7 @@ def __init__(
n_samples=n_samples,
n_months=n_months,
save_names=save_names,
- metadata=metadata,
+ variable_metadata=variable_metadata,
coords=coords,
)
self._prediction_writer = MonthlyDataWriter(
@@ -60,7 +64,7 @@ def __init__(
n_samples=n_samples,
n_months=n_months,
save_names=save_names,
- metadata=metadata,
+ variable_metadata=variable_metadata,
coords=coords,
)
@@ -69,13 +73,13 @@ def append_batch(
target: Dict[str, torch.Tensor],
prediction: Dict[str, torch.Tensor],
start_timestep: int,
- batch_times: xr.DataArray,
+ batch_time: xr.DataArray,
):
self._target_writer.append_batch(
- data=target, start_timestep=start_timestep, batch_times=batch_times
+ data=target, start_timestep=start_timestep, batch_time=batch_time
)
self._prediction_writer.append_batch(
- data=prediction, start_timestep=start_timestep, batch_times=batch_times
+ data=prediction, start_timestep=start_timestep, batch_time=batch_time
)
def flush(self):
@@ -97,7 +101,7 @@ def __init__(
n_samples: int,
n_months: int,
save_names: Optional[Sequence[str]],
- metadata: Mapping[str, VariableMetadata],
+ variable_metadata: Mapping[str, VariableMetadata],
coords: Mapping[str, np.ndarray],
):
"""
@@ -109,14 +113,14 @@ def __init__(
n_months: Number of months to write to the file.
save_names: Names of variables to save in the predictions netcdf file.
If None, all predicted variables will be saved.
- metadata: Metadata for each variable to be written to the file.
+ variable_metadata: Metadata for each variable to be written to the file.
coords: Coordinate data to be written to the file.
"""
if label != "":
label = "_" + label
filename = str(Path(path) / f"monthly_mean{label}.nc")
self._save_names = save_names
- self.metadata = metadata
+ self.variable_metadata = variable_metadata
self.coords = coords
self.dataset = Dataset(filename, "w", format="NETCDF4")
self.dataset.createDimension(LEAD_TIME_DIM, n_months)
@@ -137,10 +141,10 @@ def __init__(
)
self.dataset.variables[VALID_TIME].units = TIME_UNITS
self.dataset.variables[COUNTS][:] = 0
- self._n_lat: Optional[int] = None
- self._n_lon: Optional[int] = None
+
self._init_years = np.full([n_samples], -1, dtype=int)
self._init_months = np.full([n_samples], -1, dtype=int)
+ self._dataset_dims_created = False
def _get_initial_year_and_month(
self,
@@ -165,7 +169,7 @@ def _get_initial_year_and_month(
self.dataset.variables[VALID_TIME][:, :] = days_since_reference + 14
return (self._init_years, self._init_months)
- def _get_month_indices(self, batch_times: xr.DataArray) -> np.ndarray:
+ def _get_month_indices(self, batch_time: xr.DataArray) -> np.ndarray:
"""
Get the month indices for the batch of data.
@@ -174,16 +178,16 @@ def _get_month_indices(self, batch_times: xr.DataArray) -> np.ndarray:
indices in this and future calls.
Args:
- batch_times: Time coordinates for each sample in the batch, of shape
+ batch_time: Time coordinate for each sample in the batch, of shape
[ensemble_member, lead_time].
Returns:
Month indices for the batch of data.
"""
- calendar = batch_times.dt.calendar
- years = batch_times.dt.year.values
+ calendar = batch_time.dt.calendar
+ years = batch_time.dt.year.values
# datetime months are 1-indexed, we want 0-indexed
- months = batch_times.dt.month.values - 1
+ months = batch_time.dt.month.values - 1
init_years, init_months = self._get_initial_year_and_month(
years=years[:, 0], months=months[:, 0], calendar=calendar
)
@@ -198,7 +202,7 @@ def append_batch(
self,
data: Dict[str, torch.Tensor],
start_timestep: int,
- batch_times: xr.DataArray,
+ batch_time: xr.DataArray,
):
"""
Append a batch of data to the file.
@@ -206,40 +210,39 @@ def append_batch(
Args:
data: Values to store.
start_timestep: Timestep index for the start of the batch, unused.
- batch_times: Time coordinates for each sample in the batch.
+ batch_time: Time coordinate for each sample in the batch.
"""
del start_timestep # unused
n_samples_data = list(data.values())[0].shape[0]
- n_samples_time = batch_times.sizes["sample"]
+ n_samples_time = batch_time.sizes["sample"]
if n_samples_data != n_samples_time:
raise ValueError(
f"Batch size mismatch, data has {n_samples_data} samples "
- f"and times has {n_samples_time} samples."
+ f"and batch_time has {n_samples_time} samples."
)
n_times_data = list(data.values())[0].shape[1]
- n_times_time = batch_times.sizes["time"]
+ n_times_time = batch_time.sizes["time"]
if n_times_data != n_times_time:
raise ValueError(
f"Batch time dimension mismatch, data has {n_times_data} times "
- f"and times has {n_times_time} times."
+ f"and batch_time has {n_times_time} times."
)
- if self._n_lat is None:
- self._n_lat = data[next(iter(data.keys()))].shape[-2]
- self.dataset.createDimension("lat", self._n_lat)
- if "lat" in self.coords:
- self.dataset.createVariable("lat", "f4", ("lat",))
- self.dataset.variables["lat"][:] = self.coords["lat"]
- if self._n_lon is None:
- self._n_lon = data[next(iter(data.keys()))].shape[-1]
- self.dataset.createDimension("lon", self._n_lon)
- if "lon" in self.coords:
- self.dataset.createVariable("lon", "f4", ("lon",))
- self.dataset.variables["lon"][:] = self.coords["lon"]
-
- dims = (ENSEMBLE_DIM, LEAD_TIME_DIM, "lat", "lon")
+ if not self._dataset_dims_created:
+ _dim_info = DIM_INFO_HEALPIX if "face" in self.coords else DIM_INFO_LATLON
+ _ordered_names = []
+ for dim in _dim_info:
+ dim_size = data[next(iter(data.keys()))].shape[dim.index]
+ self.dataset.createDimension(dim.name, dim_size)
+ if dim.name in self.coords:
+ self.dataset.createVariable(dim.name, "f4", (dim.name,))
+ self.dataset.variables[dim.name][:] = self.coords[dim.name]
+ _ordered_names.append(dim.name)
+ dims = (ENSEMBLE_DIM, LEAD_TIME_DIM, *_ordered_names)
+ self._dataset_dims_created = True
+
save_names = self._get_variable_names_to_save(data.keys())
- months = self._get_month_indices(batch_times)
+ months = self._get_month_indices(batch_time)
month_min = np.min(months)
month_range = np.max(months) - month_min + 1
count_data = self.dataset.variables[COUNTS][
@@ -255,13 +258,13 @@ def append_batch(
fill_value=np.nan,
)
self.dataset.variables[variable_name][:] = 0.0
- if variable_name in self.metadata:
- self.dataset.variables[variable_name].units = self.metadata[
+ if variable_name in self.variable_metadata:
+ self.dataset.variables[
variable_name
- ].units
- self.dataset.variables[variable_name].long_name = self.metadata[
+ ].units = self.variable_metadata[variable_name].units
+ self.dataset.variables[
variable_name
- ].long_name
+ ].long_name = self.variable_metadata[variable_name].long_name
self.dataset.variables[variable_name].coordinates = " ".join(
[INIT_TIME, VALID_TIME]
)
@@ -385,5 +388,7 @@ def get_days_since_reference(
freq="MS",
calendar=calendar,
)
- days_since_reference[i, :] = (dates_sample - reference_date).days
+ days_since_reference[i, :] = (
+ dates_sample.values - reference_date
+ ) // datetime.timedelta(days=1)
return days_since_reference
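
# Hedged illustration (not part of the diff) of the floor-division idiom the revised
# get_days_since_reference relies on: dividing a timedelta by one day yields whole
# days, elementwise over an array, without using the `.days` attribute.
import datetime

import numpy as np

deltas = np.array(
    [datetime.timedelta(days=45, hours=12), datetime.timedelta(days=90)]
)
print(deltas // datetime.timedelta(days=1))  # -> [45 90]
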
diff --git a/fme/fme/ace/inference/data_writer/raw.py b/fme/fme/ace/inference/data_writer/raw.py
index b4f37b8..5026957 100644
--- a/fme/fme/ace/inference/data_writer/raw.py
+++ b/fme/fme/ace/inference/data_writer/raw.py
@@ -9,12 +9,16 @@
import xarray as xr
from netCDF4 import Dataset
-from fme.ace.inference.data_writer.utils import get_all_names
-from fme.core.data_loading.data_typing import VariableMetadata
+from fme.ace.inference.data_writer.utils import (
+ DIM_INFO_HEALPIX,
+ DIM_INFO_LATLON,
+ get_all_names,
+)
+from fme.core.dataset.data_typing import VariableMetadata
LEAD_TIME_DIM = "time"
LEAD_TIME_UNITS = "microseconds"
-SAMPLE_DIM = "sample"
+IC_DIM = "sample"
INIT_TIME = "init_time"
INIT_TIME_UNITS = "microseconds since 1970-01-01 00:00:00"
VALID_TIME = "valid_time"
@@ -30,25 +34,25 @@ class PairedRawDataWriter:
def __init__(
self,
path: str,
- n_samples: int,
+ n_initial_conditions: int,
save_names: Optional[Sequence[str]],
- metadata: Mapping[str, VariableMetadata],
+ variable_metadata: Mapping[str, VariableMetadata],
coords: Mapping[str, np.ndarray],
):
self._target_writer = RawDataWriter(
path=path,
label="autoregressive_target.nc",
- n_samples=n_samples,
+ n_initial_conditions=n_initial_conditions,
save_names=save_names,
- metadata=metadata,
+ variable_metadata=variable_metadata,
coords=coords,
)
self._prediction_writer = RawDataWriter(
path=path,
label="autoregressive_predictions.nc",
- n_samples=n_samples,
+ n_initial_conditions=n_initial_conditions,
save_names=save_names,
- metadata=metadata,
+ variable_metadata=variable_metadata,
coords=coords,
)
@@ -57,17 +61,17 @@ def append_batch(
target: Dict[str, torch.Tensor],
prediction: Dict[str, torch.Tensor],
start_timestep: int,
- batch_times: xr.DataArray,
+ batch_time: xr.DataArray,
):
self._target_writer.append_batch(
data=target,
start_timestep=start_timestep,
- batch_times=batch_times,
+ batch_time=batch_time,
)
self._prediction_writer.append_batch(
data=prediction,
start_timestep=start_timestep,
- batch_times=batch_times,
+ batch_time=batch_time,
)
def flush(self):
@@ -84,36 +88,36 @@ def __init__(
self,
path: str,
label: str,
- n_samples: int,
+ n_initial_conditions: int,
save_names: Optional[Sequence[str]],
- metadata: Mapping[str, VariableMetadata],
+ variable_metadata: Mapping[str, VariableMetadata],
coords: Mapping[str, np.ndarray],
):
"""
Args:
- filename: Path to write netCDF file(s).
- n_samples: Number of samples to write to the file. This might correspond
- to a number of initial conditions, or some other grouping of samples.
+ path: Directory within which to write the file.
+ label: Name of the file to write.
+ n_initial_conditions: Number of initial conditions / timeseries
+ to write to the file.
save_names: Names of variables to save in the output file.
If None, all provided variables will be saved.
- metadata: Metadata for each variable to be written to the file.
+ variable_metadata: Metadata for each variable to be written to the file.
coords: Coordinate data to be written to the file.
"""
filename = str(Path(path) / label)
self._save_names = save_names
- self.metadata = metadata
+ self.variable_metadata = variable_metadata
self.coords = coords
self.dataset = Dataset(filename, "w", format="NETCDF4")
self.dataset.createDimension(LEAD_TIME_DIM, None) # unlimited dimension
self.dataset.createVariable(LEAD_TIME_DIM, "i8", (LEAD_TIME_DIM,))
self.dataset.variables[LEAD_TIME_DIM].units = LEAD_TIME_UNITS
- self.dataset.createDimension(SAMPLE_DIM, n_samples)
- self.dataset.createVariable(INIT_TIME, "i8", (SAMPLE_DIM,))
+ self.dataset.createDimension(IC_DIM, n_initial_conditions)
+ self.dataset.createVariable(INIT_TIME, "i8", (IC_DIM,))
self.dataset.variables[INIT_TIME].units = INIT_TIME_UNITS
- self.dataset.createVariable(VALID_TIME, "i8", (SAMPLE_DIM, LEAD_TIME_DIM))
+ self.dataset.createVariable(VALID_TIME, "i8", (IC_DIM, LEAD_TIME_DIM))
self.dataset.variables[VALID_TIME].units = INIT_TIME_UNITS
- self._n_lat: Optional[int] = None
- self._n_lon: Optional[int] = None
+ self._dataset_dims_created = False
def _get_variable_names_to_save(
self, *data_varnames: Iterable[str]
@@ -124,7 +128,7 @@ def append_batch(
self,
data: Dict[str, torch.Tensor],
start_timestep: int,
- batch_times: xr.DataArray,
+ batch_time: xr.DataArray,
):
"""
Append a batch of data to the file.
@@ -132,37 +136,38 @@ def append_batch(
Args:
data: Data to be written to file.
start_timestep: Timestep (lead time dim) at which to start writing.
- batch_times: Time coordinates for each sample in the batch.
+ batch_time: Time coordinate for each sample in the batch.
"""
+ if self.dataset is None:
+ return
n_samples_data = list(data.values())[0].shape[0]
- n_samples_time = batch_times.sizes["sample"]
+ n_samples_time = batch_time.sizes["sample"]
if n_samples_data != n_samples_time:
raise ValueError(
f"Batch size mismatch, data has {n_samples_data} samples "
f"and times has {n_samples_time} samples."
)
n_times_data = list(data.values())[0].shape[1]
- n_times_time = batch_times.sizes["time"]
+ n_times_time = batch_time.sizes["time"]
if n_times_data != n_times_time:
raise ValueError(
f"Batch time dimension mismatch, data has {n_times_data} times "
- f"and times has {n_times_time} times."
+ f"and time has {n_times_time} times."
)
- if self._n_lat is None:
- self._n_lat = data[next(iter(data.keys()))].shape[-2]
- self.dataset.createDimension("lat", self._n_lat)
- if "lat" in self.coords:
- self.dataset.createVariable("lat", "f4", ("lat",))
- self.dataset.variables["lat"][:] = self.coords["lat"]
- if self._n_lon is None:
- self._n_lon = data[next(iter(data.keys()))].shape[-1]
- self.dataset.createDimension("lon", self._n_lon)
- if "lon" in self.coords:
- self.dataset.createVariable("lon", "f4", ("lon",))
- self.dataset.variables["lon"][:] = self.coords["lon"]
+ if not self._dataset_dims_created:
+ _dim_info = DIM_INFO_HEALPIX if "face" in self.coords else DIM_INFO_LATLON
+ _ordered_names = []
+ for dim in _dim_info:
+ dim_size = data[next(iter(data.keys()))].shape[dim.index]
+ self.dataset.createDimension(dim.name, dim_size)
+ if dim.name in self.coords:
+ self.dataset.createVariable(dim.name, "f4", (dim.name,))
+ self.dataset.variables[dim.name][:] = self.coords[dim.name]
+ _ordered_names.append(dim.name)
+ dims = (IC_DIM, LEAD_TIME_DIM, *_ordered_names)
+ self._dataset_dims_created = True
- dims = (SAMPLE_DIM, LEAD_TIME_DIM, "lat", "lon")
save_names = self._get_variable_names_to_save(data.keys())
for variable_name in save_names:
# define the variable if it doesn't exist
@@ -173,13 +178,13 @@ def append_batch(
dims,
fill_value=np.nan,
)
- if variable_name in self.metadata:
- self.dataset.variables[variable_name].units = self.metadata[
+ if variable_name in self.variable_metadata:
+ self.dataset.variables[
variable_name
- ].units
- self.dataset.variables[variable_name].long_name = self.metadata[
+ ].units = self.variable_metadata[variable_name].units
+ self.dataset.variables[
variable_name
- ].long_name
+ ].long_name = self.variable_metadata[variable_name].long_name
self.dataset.variables[variable_name].coordinates = " ".join(
[INIT_TIME, VALID_TIME]
)
@@ -194,12 +199,12 @@ def append_batch(
# handle time dimensions
if not hasattr(self.dataset.variables[INIT_TIME], "calendar"):
- self.dataset.variables[INIT_TIME].calendar = batch_times.dt.calendar
+ self.dataset.variables[INIT_TIME].calendar = batch_time.dt.calendar
if not hasattr(self.dataset.variables[VALID_TIME], "calendar"):
- self.dataset.variables[VALID_TIME].calendar = batch_times.dt.calendar
+ self.dataset.variables[VALID_TIME].calendar = batch_time.dt.calendar
if start_timestep == 0:
- init_times: np.ndarray = batch_times.isel(time=0).values
+ init_times: np.ndarray = batch_time.isel(time=0).values
init_times_numeric: np.ndarray = cftime.date2num(
init_times,
units=self.dataset.variables[INIT_TIME].units,
@@ -216,22 +221,22 @@ def append_batch(
units=self.dataset.variables[INIT_TIME].units,
calendar=self.dataset.variables[INIT_TIME].calendar,
)
- lead_times_microseconds = get_batch_lead_times_microseconds(
+ lead_time_microseconds = get_batch_lead_time_microseconds(
init_times,
- batch_times.values,
+ batch_time.values,
)
self.dataset.variables[LEAD_TIME_DIM][
- start_timestep : start_timestep + lead_times_microseconds.shape[0]
- ] = lead_times_microseconds
+ start_timestep : start_timestep + lead_time_microseconds.shape[0]
+ ] = lead_time_microseconds
valid_times_numeric: np.ndarray = cftime.date2num(
- batch_times.values,
+ batch_time.values,
units=self.dataset.variables[VALID_TIME].units,
calendar=self.dataset.variables[VALID_TIME].calendar,
)
self.dataset.variables[VALID_TIME][
:,
- start_timestep : start_timestep + lead_times_microseconds.shape[0],
+ start_timestep : start_timestep + lead_time_microseconds.shape[0],
] = valid_times_numeric
self.dataset.sync() # Flush the data to disk
@@ -243,24 +248,24 @@ def flush(self):
self.dataset.sync()
-def get_batch_lead_times_microseconds(
- init_times: npt.NDArray[cftime.datetime], batch_times: npt.NDArray[cftime.datetime]
+def get_batch_lead_time_microseconds(
+ init_time: npt.NDArray[cftime.datetime], batch_time: npt.NDArray[cftime.datetime]
) -> npt.NDArray[np.int64]:
"""
- Get the lead times in seconds for the batch.
+ Get the lead time in microseconds for the batch.
Assert that they are the same for each sample.
Args:
- init_times: Initialization time for each sample in the batch.
- batch_times: Full array of times for each sample in the batch.
+ init_time: Initialization time for each sample in the batch.
+ batch_time: Array of time coordinates for each sample in the batch.
Returns:
- Lead times in microseconds for the batch
+ Lead time in microseconds for the batch
"""
- if init_times.shape[0] != batch_times.shape[0]:
+ if init_time.shape[0] != batch_time.shape[0]:
raise ValueError(
- f"Number of init times ({len(init_times)}) must "
- f"match number of batch times ({len(batch_times)})"
+ f"Number of init times ({len(init_time)}) must "
+ f"match number of batch times ({len(batch_time)})"
)
# Carry out timedelta arithmetic in NumPy arrays to avoid xarray's automatic
# casting of datetime.timedelta objects to timedelta64[ns] values, which would
@@ -268,12 +273,12 @@ def get_batch_lead_times_microseconds(
# ~292 years. See
# https://numpy.org/doc/stable/reference/arrays.datetime.html#datetime-units
# for more details on the limits of various precision timedeltas.
- lead_times: npt.NDArray[datetime.timedelta] = ( # type: ignore
- batch_times - init_times[:, None]
+ lead_time: npt.NDArray[datetime.timedelta] = ( # type: ignore
+ batch_time - init_time[:, None]
)
- lead_times_microseconds: npt.NDArray[np.int64] = (
- lead_times // datetime.timedelta(microseconds=1)
+ lead_time_microseconds: npt.NDArray[np.int64] = (
+ lead_time // datetime.timedelta(microseconds=1)
).astype(np.int64)
- if not np.all(lead_times_microseconds == lead_times_microseconds[0, :]):
+ if not np.all(lead_time_microseconds == lead_time_microseconds[0, :]):
raise ValueError("Lead times are not the same for each sample in the batch.")
- return lead_times_microseconds[0, :]
+ return lead_time_microseconds[0, :]
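
# Hedged worked example (not part of the diff): the renamed helper subtracts each
# sample's init time from its time coordinate and floor-divides by one microsecond,
# requiring identical lead times across samples. The dates below are illustrative.
import datetime

import cftime
import numpy as np

init_time = np.array([cftime.DatetimeJulian(2020, 1, 1, 0) for _ in range(2)])
batch_time = np.array(
    [[cftime.DatetimeJulian(2020, 1, 1, h) for h in (0, 6, 12)] for _ in range(2)]
)
lead = (batch_time - init_time[:, None]) // datetime.timedelta(microseconds=1)
print(lead[0])  # [0, 21600000000, 43200000000] -> 0 h, 6 h, 12 h leads
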
diff --git a/fme/fme/ace/inference/data_writer/test_data_writer.py b/fme/fme/ace/inference/data_writer/test_data_writer.py
index 00b9286..da083d8 100644
--- a/fme/fme/ace/inference/data_writer/test_data_writer.py
+++ b/fme/fme/ace/inference/data_writer/test_data_writer.py
@@ -8,13 +8,15 @@
import xarray as xr
from netCDF4 import Dataset
+from fme.ace.data_loading.batch_data import BatchData, PairedData
from fme.ace.inference.data_writer.main import (
DataWriter,
DataWriterConfig,
PairedDataWriter,
)
-from fme.ace.inference.data_writer.raw import get_batch_lead_times_microseconds
+from fme.ace.inference.data_writer.raw import get_batch_lead_time_microseconds
from fme.ace.inference.data_writer.time_coarsen import TimeCoarsenConfig
+from fme.core.device import get_device
CALENDAR_CFTIME = {
"julian": cftime.DatetimeJulian,
@@ -34,8 +36,6 @@ def test_data_writer_config_save_names():
save_monthly_files=False,
save_histogram_files=False,
)
- with pytest.warns():
- DataWriterConfig(names=variable_names, **kwargs)
for save_writer in [
"save_prediction_files",
"save_monthly_files",
@@ -43,7 +43,7 @@ def test_data_writer_config_save_names():
]:
kwargs_copy = kwargs.copy()
kwargs_copy.update({save_writer: True})
- DataWriterConfig(names=variable_names, **kwargs_copy)
+ DataWriterConfig(names=variable_names, **kwargs_copy) # type: ignore
class TestDataWriter:
@@ -63,11 +63,13 @@ def calendar(self, request):
"""
return request.param
- def get_batch_times(self, start_time, end_time, freq, n_samples, calendar="julian"):
+ def get_batch_time(
+ self, start_time, end_time, freq, n_initial_conditions, calendar="julian"
+ ):
datetime_class = CALENDAR_CFTIME[calendar]
start_time = datetime_class(*start_time)
end_time = datetime_class(*end_time)
- batch_times = xr.DataArray(
+ batch_time = xr.DataArray(
xr.cftime_range(
start_time,
end_time,
@@ -76,7 +78,9 @@ def get_batch_times(self, start_time, end_time, freq, n_samples, calendar="julia
).values,
dims="time",
)
- return xr.concat([batch_times for _ in range(n_samples)], dim="sample")
+ return xr.concat(
+ [batch_time for _ in range(n_initial_conditions)], dim="sample"
+ )
@pytest.fixture
def sample_metadata(self):
@@ -86,30 +90,54 @@ def sample_metadata(self):
}
@pytest.fixture
- def sample_target_data(self):
+ def sample_target_data(self, request):
+ shape = request.param
data = {
- # sample, time, lat, lon
- "temp": torch.rand((2, 3, 4, 5)),
- "humidity": torch.rand((2, 3, 4, 5)), # input-only variable
+ # sample, time, *horizontal_dims
+ "temp": torch.rand(shape),
+ "humidity": torch.rand(shape), # input-only variable
"pressure": torch.rand(
- (2, 3, 4, 5)
+ shape
), # Extra variable for which there's no metadata
- "precipitation": torch.rand((2, 3, 4, 5)), # optionally saved
+ "precipitation": torch.rand(shape), # optionally saved
}
return data
@pytest.fixture
- def sample_prediction_data(self):
+ def sample_prediction_data(self, request):
+ shape = request.param
data = {
# sample, time, lat, lon
- "temp": torch.rand((2, 3, 4, 5)),
+ "temp": torch.rand(shape),
"pressure": torch.rand(
- (2, 3, 4, 5)
+ shape
), # Extra variable for which there's no metadata
- "precipitation": torch.rand((2, 3, 4, 5)), # optionally saved
+ "precipitation": torch.rand(shape), # optionally saved
}
return data
+ @pytest.fixture
+ def coords(self, request):
+ return request.param
+
+ @pytest.mark.parametrize(
+ "sample_target_data, sample_prediction_data, coords",
+ [
+ pytest.param(
+ (2, 3, 4, 5),
+ (2, 3, 4, 5),
+ {"lat": np.arange(4), "lon": np.arange(5)},
+ id="LatLon",
+ ),
+ pytest.param(
+ (2, 3, 6, 4, 5),
+ (2, 3, 6, 4, 5),
+ {"face": np.arange(6), "height": np.arange(4), "width": np.arange(5)},
+ id="HEALPix",
+ ),
+ ],
+ indirect=True,
+ )
def test_append_batch(
self,
sample_metadata,
@@ -117,52 +145,54 @@ def test_append_batch(
sample_prediction_data,
tmp_path,
calendar,
+ coords,
):
- n_samples = 2
+ n_initial_conditions = 2
n_timesteps = 6
writer = PairedDataWriter(
str(tmp_path),
- n_samples=n_samples,
+ n_initial_conditions=n_initial_conditions,
n_timesteps=n_timesteps,
timestep=TIMESTEP,
- metadata=sample_metadata,
- coords={"lat": np.arange(4), "lon": np.arange(5)},
+ variable_metadata=sample_metadata,
+ coords=coords,
enable_prediction_netcdfs=True,
enable_video_netcdfs=False,
enable_monthly_netcdfs=True,
enable_histogram_netcdfs=True,
save_names=None,
- prognostic_names=[],
)
start_time = (2020, 1, 1, 0, 0, 0)
end_time = (2020, 1, 1, 12, 0, 0)
- batch_times = self.get_batch_times(
+ batch_time = self.get_batch_time(
start_time=start_time,
end_time=end_time,
- freq="6H",
- n_samples=n_samples,
+ freq="6h",
+ n_initial_conditions=n_initial_conditions,
calendar=calendar,
)
writer.append_batch(
- sample_target_data,
- sample_prediction_data,
- start_timestep=0,
- batch_times=batch_times,
+ batch=PairedData(
+ prediction=sample_prediction_data,
+ target=sample_target_data,
+ time=batch_time,
+ ),
)
start_time_2 = (2020, 1, 1, 18, 0, 0)
end_time_2 = (2020, 1, 2, 6, 0, 0)
- batch_times = self.get_batch_times(
+ batch_time = self.get_batch_time(
start_time=start_time_2,
end_time=end_time_2,
- freq="6H",
- n_samples=n_samples,
+ freq="6h",
+ n_initial_conditions=n_initial_conditions,
calendar=calendar,
)
writer.append_batch(
- sample_target_data,
- sample_prediction_data,
- start_timestep=3,
- batch_times=batch_times,
+ batch=PairedData(
+ prediction=sample_prediction_data,
+ target=sample_target_data,
+ time=batch_time,
+ ),
)
writer.flush()
@@ -171,14 +201,14 @@ def test_append_batch(
assert dataset["time"].units == "microseconds"
assert dataset["init_time"].units == "microseconds since 1970-01-01 00:00:00"
assert dataset["init_time"].calendar == calendar
+ horizontal_shape = (4, 5) if "lat" in coords else (6, 4, 5)
for var_name in set(sample_prediction_data.keys()):
var_data = dataset.variables[var_name][:]
assert var_data.shape == (
- n_samples,
+ n_initial_conditions,
n_timesteps,
- 4,
- 5,
- ) # sample, time, lat, lon
+ *horizontal_shape,
+ )
assert not np.isnan(var_data).any(), "unexpected NaNs in prediction data"
if var_name in sample_metadata:
assert (
@@ -196,7 +226,7 @@ def test_append_batch(
# Open the target output file and do smaller set of checks
dataset = Dataset(tmp_path / "autoregressive_target.nc", "r")
- coord_names = {"time", "init_time", "valid_time", "lat", "lon"}
+ coord_names = {"time", "init_time", "valid_time", *set(coords)}
assert set(dataset.variables) == set(sample_target_data) | coord_names
# Open the file again with xarray and check the time coordinates,
@@ -222,7 +252,10 @@ def test_append_batch(
)
xr.testing.assert_equal(ds["time"], expected_lead_times)
expected_init_times = xr.DataArray(
- [CALENDAR_CFTIME[calendar](*start_time) for _ in range(n_samples)],
+ [
+ CALENDAR_CFTIME[calendar](*start_time)
+ for _ in range(n_initial_conditions)
+ ],
dims=["sample"],
)
expected_init_times = expected_init_times.assign_coords(
@@ -245,7 +278,7 @@ def test_append_batch(
assert same_count_each_timestep
with xr.open_dataset(tmp_path / "monthly_mean_predictions.nc") as ds:
- assert ds.counts.sum() == n_samples * n_timesteps
+ assert ds.counts.sum() == n_initial_conditions * n_timesteps
assert np.sum(np.isnan(ds["precipitation"])) == 0
assert np.sum(np.isnan(ds["temp"])) == 0
assert np.sum(np.isnan(ds["pressure"])) == 0
@@ -253,6 +286,11 @@ def test_append_batch(
assert np.all(ds.init_time.dt.year.values >= 0)
assert np.all(ds.valid_time.dt.month.values >= 0)
+ @pytest.mark.parametrize(
+ "sample_target_data, sample_prediction_data",
+ [pytest.param((2, 3, 4, 5), (2, 3, 4, 5), id="LatLon")],
+ indirect=True,
+ )
@pytest.mark.parametrize(
["save_names"],
[
@@ -271,31 +309,31 @@ def test_append_batch_save_names(
n_samples = 2
writer = PairedDataWriter(
str(tmp_path),
- n_samples=n_samples,
+ n_initial_conditions=n_samples,
n_timesteps=4, # unused
timestep=TIMESTEP,
- metadata=sample_metadata,
+ variable_metadata=sample_metadata,
coords={"lat": np.arange(4), "lon": np.arange(5)},
enable_prediction_netcdfs=True,
enable_video_netcdfs=False,
enable_monthly_netcdfs=True,
save_names=save_names,
enable_histogram_netcdfs=True,
- prognostic_names=save_names or [],
)
start_time = (2020, 1, 1, 0, 0, 0)
end_time = (2020, 1, 1, 12, 0, 0)
- batch_times = self.get_batch_times(
+ batch_time = self.get_batch_time(
start_time=start_time,
end_time=end_time,
- freq="6H",
- n_samples=n_samples,
+ freq="6h",
+ n_initial_conditions=n_samples,
)
writer.append_batch(
- sample_target_data,
- sample_prediction_data,
- start_timestep=0,
- batch_times=batch_times,
+ batch=PairedData(
+ prediction=sample_prediction_data,
+ target=sample_target_data,
+ time=batch_time,
+ ),
)
writer.flush()
dataset = Dataset(tmp_path / "autoregressive_predictions.nc", "r")
@@ -329,88 +367,99 @@ def test_append_batch_save_names(
}
)
+ @pytest.mark.parametrize(
+ "sample_target_data, sample_prediction_data",
+ [pytest.param((2, 3, 4, 5), (2, 3, 4, 5), id="LatLon")],
+ indirect=True,
+ )
def test_append_batch_data_time_mismatch(
self, sample_metadata, sample_target_data, sample_prediction_data, tmp_path
):
n_samples = 2
writer = PairedDataWriter(
str(tmp_path),
- n_samples=n_samples,
+ n_initial_conditions=n_samples,
n_timesteps=3,
timestep=TIMESTEP,
- metadata=sample_metadata,
+ variable_metadata=sample_metadata,
coords={"lat": np.arange(4), "lon": np.arange(5)},
enable_prediction_netcdfs=True,
enable_video_netcdfs=False,
enable_monthly_netcdfs=True,
save_names=None,
enable_histogram_netcdfs=True,
- prognostic_names=[],
)
start_time = (2020, 1, 1, 0, 0, 0)
end_time = (2020, 1, 1, 12, 0, 0)
- batch_times = self.get_batch_times(
+ batch_time = self.get_batch_time(
start_time=start_time,
end_time=end_time,
- freq="6H",
- n_samples=n_samples + 1,
+ freq="6h",
+ n_initial_conditions=n_samples + 1,
)
with pytest.raises(ValueError):
writer.append_batch(
- sample_target_data,
- sample_prediction_data,
- start_timestep=0,
- batch_times=batch_times,
+ batch=PairedData(
+ prediction=sample_prediction_data,
+ target=sample_target_data,
+ time=batch_time,
+ ),
)
def test_prediction_only_append_batch(self, sample_metadata, tmp_path, calendar):
n_samples = 2
n_timesteps = 8
coarsen_factor = 2
+ device = get_device()
prediction_data = {
- "temp": torch.rand((n_samples, n_timesteps // coarsen_factor, 4, 5)),
- "pressure": torch.rand((n_samples, n_timesteps // coarsen_factor, 4, 5)),
+ "temp": torch.rand(
+ (n_samples, n_timesteps // coarsen_factor, 4, 5), device=device
+ ),
+ "pressure": torch.rand(
+ (n_samples, n_timesteps // coarsen_factor, 4, 5), device=device
+ ),
}
writer = DataWriter(
str(tmp_path),
- n_samples=n_samples,
+ n_initial_conditions=n_samples,
n_timesteps=n_timesteps,
- metadata=sample_metadata,
+ variable_metadata=sample_metadata,
coords={"lat": np.arange(4), "lon": np.arange(5)},
timestep=TIMESTEP,
enable_prediction_netcdfs=True,
enable_monthly_netcdfs=True,
save_names=None,
- prognostic_names=["temp"],
time_coarsen=TimeCoarsenConfig(coarsen_factor),
)
start_time = (2020, 1, 1, 0, 0, 0)
end_time = (2020, 1, 1, 18, 0, 0)
- batch_times = self.get_batch_times(
+ batch_time = self.get_batch_time(
start_time=start_time,
end_time=end_time,
- freq="6H",
- n_samples=n_samples,
+ freq="6h",
+ n_initial_conditions=n_samples,
calendar=calendar,
)
writer.append_batch(
- prediction_data,
- start_timestep=0,
- batch_times=batch_times,
+ batch=BatchData(
+ data=prediction_data,
+ time=batch_time,
+ ),
)
start_time_2 = (2020, 1, 2, 0, 0, 0)
end_time_2 = (2020, 1, 2, 18, 0, 0)
- batch_times = self.get_batch_times(
+ batch_time = self.get_batch_time(
start_time=start_time_2,
end_time=end_time_2,
- freq="6H",
- n_samples=n_samples,
+ freq="6h",
+ n_initial_conditions=n_samples,
calendar=calendar,
)
writer.append_batch(
- prediction_data,
- start_timestep=4,
- batch_times=batch_times,
+ batch=BatchData(
+ data=prediction_data,
+ time=batch_time,
+ ),
)
writer.flush()
@@ -428,11 +477,9 @@ def test_prediction_only_append_batch(self, sample_metadata, tmp_path, calendar)
assert np.all(ds.init_time.dt.year.values >= 0)
assert np.all(ds.valid_time.dt.month.values >= 0)
- xr.open_dataset(tmp_path / "restart.nc")
-
@pytest.mark.parametrize(
- ["init_times", "batch_times", "expected"],
+ ["init_times", "batch_time", "expected"],
[
pytest.param(
np.array([cftime.DatetimeJulian(2020, 1, 1, 0, 0, 0) for _ in range(3)]),
@@ -440,7 +487,7 @@ def test_prediction_only_append_batch(self, sample_metadata, tmp_path, calendar)
[
xr.cftime_range(
cftime.DatetimeJulian(2020, 1, 1, 0, 0, 0),
- freq="6H",
+ freq="6h",
periods=3,
).values
for _ in range(3)
@@ -462,7 +509,7 @@ def test_prediction_only_append_batch(self, sample_metadata, tmp_path, calendar)
[
xr.cftime_range(
cftime.DatetimeJulian(2020, 1, 1, 6 * i, 0, 0),
- freq="6H",
+ freq="6h",
periods=3,
)
for i in range(3)
@@ -484,7 +531,7 @@ def test_prediction_only_append_batch(self, sample_metadata, tmp_path, calendar)
[
xr.cftime_range(
cftime.DatetimeJulian(2020, 1, 2, 6 * i, 0, 0),
- freq="6H",
+ freq="6h",
periods=3,
)
for i in range(3)
@@ -500,50 +547,50 @@ def test_prediction_only_append_batch(self, sample_metadata, tmp_path, calendar)
),
],
)
-def test_get_batch_lead_times_microseconds(init_times, batch_times, expected):
- lead_time_seconds = get_batch_lead_times_microseconds(init_times, batch_times)
+def test_get_batch_lead_times_microseconds(init_times, batch_time, expected):
+ lead_time_seconds = get_batch_lead_time_microseconds(init_times, batch_time)
assert lead_time_seconds.shape == expected.shape
np.testing.assert_equal(lead_time_seconds, expected)
-def test_get_batch_lead_times_microseconds_length_mismatch():
+def test_get_batch_lead_time_microseconds_length_mismatch():
init_times = np.array(
[cftime.DatetimeJulian(2020, 1, 1, 6 * i, 0, 0) for i in range(3)]
)
- batch_times = np.array(
+ batch_time = np.array(
[
xr.cftime_range(
cftime.DatetimeJulian(2020, 1, 2, 6 * i, 0, 0),
- freq="6H",
+ freq="6h",
periods=3,
).values
for i in range(2)
],
)
with pytest.raises(ValueError):
- get_batch_lead_times_microseconds(init_times, batch_times)
+ get_batch_lead_time_microseconds(init_times, batch_time)
-def test_get_batch_lead_times_microseconds_inconsistent_samples():
+def test_get_batch_lead_time_microseconds_inconsistent_samples():
init_times = np.array(
[cftime.DatetimeJulian(2020, 1, 1, 6, 0, 0) for _ in range(2)]
)
- batch_times = np.array(
+ batch_time = np.array(
[
xr.cftime_range(
cftime.DatetimeJulian(2020, 1, 1, 6, 0, 0),
- freq="6H",
+ freq="6h",
periods=3,
),
xr.cftime_range(
cftime.DatetimeJulian(2020, 1, 1, 12, 0, 0),
- freq="6H",
+ freq="6h",
periods=3,
),
]
)
with pytest.raises(ValueError):
- get_batch_lead_times_microseconds(init_times, batch_times)
+ get_batch_lead_time_microseconds(init_times, batch_time)
@pytest.mark.parametrize(
@@ -554,17 +601,17 @@ def test_get_batch_lead_times_microseconds_inconsistent_samples():
pytest.param(1e6, True, id="1_000_000_years_fails"),
],
)
-def test_get_batch_lead_times_microseconds_overflow(years_ahead, overflow):
+def test_get_batch_lead_time_microseconds_overflow(years_ahead, overflow):
init_times = np.array([cftime.DatetimeNoLeap(2020, 1, 1)])
- batch_times = np.array([cftime.DatetimeNoLeap(2020 + years_ahead, 1, 1)])[:, None]
+ batch_time = np.array([cftime.DatetimeNoLeap(2020 + years_ahead, 1, 1)])[:, None]
days_per_year_noleap = 365
seconds_per_day = 86400
expected_lead_time_microseconds = (
MICROSECONDS_PER_SECOND * seconds_per_day * days_per_year_noleap * years_ahead
)
if not overflow:
- lead_time = get_batch_lead_times_microseconds(init_times, batch_times)
+ lead_time = get_batch_lead_time_microseconds(init_times, batch_time)
assert lead_time.item() == expected_lead_time_microseconds
else:
with pytest.raises(OverflowError):
- get_batch_lead_times_microseconds(init_times, batch_times)
+ get_batch_lead_time_microseconds(init_times, batch_time)
diff --git a/fme/fme/ace/inference/data_writer/test_main.py b/fme/fme/ace/inference/data_writer/test_main.py
new file mode 100644
index 0000000..95b311c
--- /dev/null
+++ b/fme/fme/ace/inference/data_writer/test_main.py
@@ -0,0 +1,89 @@
+import os
+import tempfile
+
+import numpy as np
+import torch
+import xarray as xr
+
+from fme.ace.data_loading.batch_data import BatchData
+from fme.ace.inference.data_writer.main import _write
+from fme.core.dataset.data_typing import VariableMetadata
+
+
+def test_write_single_timestep():
+ n_samples = 2
+ n_lat = 4
+ n_lon = 5
+ n_time = 1
+ batch = BatchData.new_on_cpu(
+ data={"air_temperature": torch.rand((n_samples, n_time, n_lat, n_lon))},
+ time=xr.DataArray(np.random.rand(n_samples, n_time), dims=["sample", "time"]),
+ horizontal_dims=["lat", "lon"],
+ )
+ with tempfile.TemporaryDirectory() as tmpdir:
+ _write(
+ data=batch,
+ path=tmpdir,
+ filename="initial_condition.nc",
+ variable_metadata={
+ "air_temperature": VariableMetadata(
+ long_name="Air Temperature", units="K"
+ )
+ },
+ coords={"lat": np.arange(n_lat), "lon": np.arange(n_lon)},
+ )
+ filename = os.path.join(tmpdir, "initial_condition.nc")
+ assert os.path.exists(filename)
+ with xr.open_dataset(filename) as ds:
+ assert "air_temperature" in ds
+ assert ds.air_temperature.shape == (n_samples, n_lat, n_lon)
+ assert ds.time.shape == (n_samples,)
+ assert ds.air_temperature.dims == ("sample", "lat", "lon")
+ xr.testing.assert_allclose(ds.time, batch.time.isel(time=0))
+ np.testing.assert_allclose(
+ ds.air_temperature.values,
+ batch.data["air_temperature"].squeeze(dim=1).cpu().numpy(),
+ )
+ np.testing.assert_allclose(ds.coords["lat"].values, np.arange(n_lat))
+ np.testing.assert_allclose(ds.coords["lon"].values, np.arange(n_lon))
+ assert ds.air_temperature.attrs["long_name"] == "Air Temperature"
+ assert ds.air_temperature.attrs["units"] == "K"
+
+
+def test_write_multiple_timesteps():
+ n_samples = 2
+ n_lat = 4
+ n_lon = 5
+ n_time = 2
+ batch = BatchData.new_on_cpu(
+ data={"air_temperature": torch.rand((n_samples, n_time, n_lat, n_lon))},
+ time=xr.DataArray(np.random.rand(n_samples, n_time), dims=["sample", "time"]),
+ horizontal_dims=["lat", "lon"],
+ )
+ with tempfile.TemporaryDirectory() as tmpdir:
+ _write(
+ data=batch,
+ path=tmpdir,
+ filename="initial_condition.nc",
+ variable_metadata={
+ "air_temperature": VariableMetadata(
+ long_name="Air Temperature", units="K"
+ )
+ },
+ coords={"lat": np.arange(n_lat), "lon": np.arange(n_lon)},
+ )
+ filename = os.path.join(tmpdir, "initial_condition.nc")
+ assert os.path.exists(filename)
+ with xr.open_dataset(filename) as ds:
+ assert "air_temperature" in ds
+ assert ds.air_temperature.shape == (n_samples, n_time, n_lat, n_lon)
+ assert ds.time.shape == (n_samples, n_time)
+ assert ds.air_temperature.dims == ("sample", "time", "lat", "lon")
+ np.testing.assert_allclose(ds.time.values, batch.time.values)
+ np.testing.assert_allclose(
+ ds.air_temperature.values, batch.data["air_temperature"].cpu().numpy()
+ )
+ np.testing.assert_allclose(ds.coords["lat"].values, np.arange(n_lat))
+ np.testing.assert_allclose(ds.coords["lon"].values, np.arange(n_lon))
+ assert ds.air_temperature.attrs["long_name"] == "Air Temperature"
+ assert ds.air_temperature.attrs["units"] == "K"
diff --git a/fme/fme/ace/inference/data_writer/test_monthly.py b/fme/fme/ace/inference/data_writer/test_monthly.py
index 4dac37b..48eb004 100644
--- a/fme/fme/ace/inference/data_writer/test_monthly.py
+++ b/fme/fme/ace/inference/data_writer/test_monthly.py
@@ -13,7 +13,7 @@
get_days_since_reference,
months_for_timesteps,
)
-from fme.core.data_loading.data_typing import VariableMetadata
+from fme.core.dataset.data_typing import VariableMetadata
TIMESTEP = datetime.timedelta(hours=6)
@@ -38,7 +38,7 @@ def test_monthly_data_writer(tmpdir, window_size: int, n_writes: int):
n_samples=n_samples,
n_months=24,
save_names=None,
- metadata={"x": VariableMetadata(units="m", long_name="x_name")},
+ variable_metadata={"x": VariableMetadata(units="m", long_name="x_name")},
coords={},
)
month_values = []
@@ -51,7 +51,7 @@ def test_monthly_data_writer(tmpdir, window_size: int, n_writes: int):
month_data = {"x": x_window}
initial_time = cftime.DatetimeProlepticGregorian(year, month, 1, 0, 0, 0)
for i_write in range(n_writes):
- times = xr.DataArray(
+ time = xr.DataArray(
[
[
initial_time + datetime.timedelta(hours=6 * i_write)
@@ -61,10 +61,8 @@ def test_monthly_data_writer(tmpdir, window_size: int, n_writes: int):
],
dims=["sample", "time"],
)
- assert times.shape == (n_samples, window_size)
- writer.append_batch(
- data=month_data, start_timestep=0, batch_times=times
- )
+ assert time.shape == (n_samples, window_size)
+ writer.append_batch(data=month_data, start_timestep=0, batch_time=time)
writer.flush()
written = xr.open_dataset(str(tmpdir / "monthly_mean_predictions.nc"))
assert written["x"].shape == (n_samples, 24, n_lat, n_lon)
@@ -93,21 +91,46 @@ def test_months_for_timesteps(n_timesteps: int, min_expected: int):
assert months_for_timesteps(n_timesteps, TIMESTEP) >= min_expected
-def test_get_days_since_reference():
- years = np.array([2020, 2021])
- months = np.array([0, 1]) # expects zero-indexed months
- reference_date = cftime.DatetimeProlepticGregorian(2020, 1, 1)
+@pytest.mark.parametrize("num_years", [2, 500])
+@pytest.mark.parametrize("calendar", ["proleptic_gregorian", "noleap"])
+def test_get_days_since_reference(num_years, calendar):
+ first_year = 2020
+ final_year = first_year + num_years - 1
+ years = np.arange(first_year, final_year + 1)
+ months = np.zeros((num_years,), dtype=int)
+ # For the last year, set the (zero-indexed) month to 1, i.e. February
+ months[-1] = 1
+ if calendar == "proleptic_gregorian":
+ reference_date = cftime.DatetimeProlepticGregorian(2020, 1, 1)
+ else:
+ reference_date = cftime.DatetimeNoLeap(2020, 1, 1)
n_months = 3
- calendar = "proleptic_gregorian"
days = get_days_since_reference(years, months, reference_date, n_months, calendar)
- assert days.shape == (2, 3)
- assert days[0, 0] == 0
- assert days[0, 1] == 31
- assert days[0, 2] == 31 + 29
- # 2020 is a leap year
- assert days[1, 0] == 366 + 31
- assert days[1, 1] == 366 + 31 + 28
- assert days[1, 2] == 366 + 31 + 28 + 31
+ assert days.shape == (num_years, 3)
+ # 2020 is a leap year in proleptic_gregorian
+ if calendar == "proleptic_gregorian":
+ assert days[0, 0] == 0
+ assert days[0, 1] == 31
+ assert days[0, 2] == 31 + 29
+ if num_years == 2:
+ assert days[1, 0] == 366 + 31
+ assert days[1, 1] == 366 + 31 + 28
+ assert days[1, 2] == 366 + 31 + 28 + 31
+ if num_years == 500:
+ # 121 is the number of leap days in 2020-2518
+ assert days[499, 0] == 182135 + 121 + 31
+ assert days[499, 1] == 182135 + 121 + 31 + 28
+ if calendar == "noleap":
+ assert days[0, 0] == 0
+ assert days[0, 1] == 31
+ assert days[0, 2] == 31 + 28
+ if num_years == 2:
+ assert days[1, 0] == 365 + 31
+ assert days[1, 1] == 365 + 31 + 28
+ assert days[1, 2] == 365 + 31 + 28 + 31
+ if num_years == 500:
+ assert days[499, 0] == 182135 + 31
+ assert days[499, 1] == 182135 + 31 + 28
@pytest.mark.parametrize(
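A quick hand check of the 500-year constants asserted in test_get_days_since_reference above: 499 elapsed years of 365 days give 499 * 365 = 182,135 days; in the proleptic_gregorian calendar the years 2020-2518 add 121 leap days (125 years divisible by 4, minus the non-leap century years 2100, 2200, 2300 and 2500), while the noleap calendar adds none.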
diff --git a/fme/fme/ace/inference/data_writer/test_restart.py b/fme/fme/ace/inference/data_writer/test_restart.py
deleted file mode 100644
index a858830..0000000
--- a/fme/fme/ace/inference/data_writer/test_restart.py
+++ /dev/null
@@ -1,130 +0,0 @@
-import numpy as np
-import torch
-import xarray as xr
-
-from fme.ace.inference.data_writer.restart import RestartWriter
-from fme.core.data_loading.data_typing import VariableMetadata
-
-
-def test_restart_saves_last_step(tmpdir):
- """
- If multiple steps are configured as restart steps, the last one should be saved.
- """
- n_sample: int = 3
- n_time: int = 2
- n_lat = 10
- n_lon = 20
- lat = np.linspace(-90, 90, n_lat)
- lon = np.linspace(-180, 180, n_lon)
- writer = RestartWriter(
- path=tmpdir,
- is_restart_step=lambda i: True,
- prognostic_names=["a", "b"],
- metadata={"a": VariableMetadata(long_name="var_a", units="m")},
- coords={"lon": lon, "lat": lat},
- )
- data = {
- "a": torch.randn(n_sample, n_time, n_lat, n_lon),
- "b": torch.randn(n_sample, n_time, n_lat, n_lon),
- }
- batch_times = xr.DataArray(
- data=np.random.uniform(size=(n_sample, n_time)),
- dims=(
- "sample",
- "time",
- ),
- )
- writer.append_batch(data, 0, batch_times)
- ds = xr.open_dataset(str(tmpdir / "restart.nc"))
- np.testing.assert_allclose(ds.a.values, data["a"][:, -1].cpu().numpy())
- np.testing.assert_allclose(ds.b.values, data["b"][:, -1].cpu().numpy())
- np.testing.assert_allclose(ds.time.values, batch_times[:, -1].values)
- assert len(ds.b.attrs) == 0
- assert len(ds.a.attrs) == 2
- assert ds.a.attrs["long_name"] == "var_a"
- assert ds.a.attrs["units"] == "m"
- np.testing.assert_allclose(ds.lon.values, lon)
- np.testing.assert_allclose(ds.lat.values, lat)
- assert ds.attrs["timestep"] == 1
-
-
-def test_restart_saves_configured_step(tmpdir):
- """
- If a specific step is configured as a restart step, that step should be saved.
- """
- n_sample: int = 3
- i_time_start = 4
- i_time_target = 6
- n_time: int = 4
- n_lat = 10
- n_lon = 20
- lat = np.linspace(-90, 90, n_lat)
- lon = np.linspace(-180, 180, n_lon)
- writer = RestartWriter(
- path=tmpdir,
- is_restart_step=lambda i: i == i_time_target,
- prognostic_names=["a", "b"],
- metadata={"a": VariableMetadata(long_name="var_a", units="m")},
- coords={"lon": lon, "lat": lat},
- )
- data = {
- "a": torch.randn(n_sample, n_time, n_lat, n_lon),
- "b": torch.randn(n_sample, n_time, n_lat, n_lon),
- }
- batch_times = xr.DataArray(
- data=np.random.uniform(size=(n_sample, n_time)),
- dims=(
- "sample",
- "time",
- ),
- )
- writer.append_batch(data, i_time_start, batch_times)
- ds = xr.open_dataset(str(tmpdir / "restart.nc"))
- np.testing.assert_allclose(
- ds.a.values, data["a"][:, i_time_target - i_time_start].cpu().numpy()
- )
- np.testing.assert_allclose(
- ds.b.values, data["b"][:, i_time_target - i_time_start].cpu().numpy()
- )
- np.testing.assert_allclose(
- ds.time.values, batch_times[:, i_time_target - i_time_start].values
- )
- assert len(ds.b.attrs) == 0
- assert len(ds.a.attrs) == 2
- assert ds.a.attrs["long_name"] == "var_a"
- assert ds.a.attrs["units"] == "m"
- np.testing.assert_allclose(ds.lon.values, lon)
- np.testing.assert_allclose(ds.lat.values, lat)
- assert ds.attrs["timestep"] == i_time_target
-
-
-def test_restart_does_not_save(tmpdir):
- """
- If no step is configured to save as restart, no restart should be saved.
- """
- n_sample: int = 3
- n_time: int = 2
- n_lat = 10
- n_lon = 20
- lat = np.linspace(-90, 90, n_lat)
- lon = np.linspace(-180, 180, n_lon)
- writer = RestartWriter(
- path=tmpdir,
- is_restart_step=lambda i: False,
- prognostic_names=["a", "b"],
- metadata={"a": VariableMetadata(long_name="var_a", units="m")},
- coords={"lon": lon, "lat": lat},
- )
- data = {
- "a": torch.randn(n_sample, n_time, n_lat, n_lon),
- "b": torch.randn(n_sample, n_time, n_lat, n_lon),
- }
- batch_times = xr.DataArray(
- data=np.random.uniform(size=(n_sample, n_time)),
- dims=(
- "sample",
- "time",
- ),
- )
- writer.append_batch(data, 0, batch_times)
- assert not (tmpdir / "restart.nc").exists()
diff --git a/fme/fme/ace/inference/data_writer/test_time_coarsen.py b/fme/fme/ace/inference/data_writer/test_time_coarsen.py
index 3902f03..2cb09cb 100644
--- a/fme/fme/ace/inference/data_writer/test_time_coarsen.py
+++ b/fme/fme/ace/inference/data_writer/test_time_coarsen.py
@@ -21,11 +21,11 @@ def get_windowed_batch(dim_sizes: Sequence[int], start_time: Sequence[int]):
.movedim(3, 1)
)
target = {VARNAME: data}
- times = get_batch_times(n_timesteps, start_time, n_samples=n_samples)
- return target, times
+ time = get_batch_time(n_timesteps, start_time, n_samples=n_samples)
+ return target, time
-def get_batch_times(
+def get_batch_time(
n_timesteps: int,
start_time: Sequence[int],
n_samples: int = 2,
@@ -49,7 +49,7 @@ def _get_time_array(
hours=6 * start_n_offset
)
time_index = xr.cftime_range(
- start=start_time, periods=n_timesteps, freq=f"{freq_hrs}H", calendar="julian"
+ start=start_time, periods=n_timesteps, freq=f"{freq_hrs}h", calendar="julian"
)
return xr.DataArray(data=time_index, dims=["time"]).drop_vars(["time"])
@@ -59,7 +59,7 @@ def _get_time_array(
"coarsen_factor",
"start_timestep",
"expected_coarsened_data",
- "expected_coarsened_times",
+ "expected_coarsened_time",
"expected_coarsened_start_timestep",
],
[
@@ -67,7 +67,7 @@ def _get_time_array(
1,
0,
[0.0, 1.0, 2.0, 3.0],
- get_batch_times(start_time=(2020, 1, 1, 0, 0, 0), n_timesteps=4),
+ get_batch_time(start_time=(2020, 1, 1, 0, 0, 0), n_timesteps=4),
0,
id="coarsen_factor_1",
),
@@ -75,7 +75,7 @@ def _get_time_array(
2,
0,
[0.5, 2.5],
- get_batch_times(
+ get_batch_time(
start_time=(2020, 1, 1, 3, 0, 0), n_timesteps=2, freq_hrs=12
),
0,
@@ -85,7 +85,7 @@ def _get_time_array(
2,
3,
[0.5, 2.5],
- get_batch_times(
+ get_batch_time(
start_time=(2020, 1, 1, 3, 0, 0), n_timesteps=2, freq_hrs=12
),
1,
@@ -95,7 +95,7 @@ def _get_time_array(
4,
0,
[1.5],
- get_batch_times(
+ get_batch_time(
start_time=(2020, 1, 1, 9, 0, 0), n_timesteps=1, freq_hrs=24
),
0,
@@ -107,41 +107,45 @@ def test_time_coarsen(
coarsen_factor: int,
start_timestep: int,
expected_coarsened_data: Sequence[float],
- expected_coarsened_times: Sequence[cftime.DatetimeJulian],
+ expected_coarsened_time: Sequence[cftime.DatetimeJulian],
expected_coarsened_start_timestep: int,
dim_sizes: Sequence[int] = DIM_SIZES,
):
- target, times = get_windowed_batch(
+ target, time = get_windowed_batch(
dim_sizes=dim_sizes, start_time=(2020, 1, 1, 0, 0, 0)
)
(
target_coarsened,
coarsened_start_timestep,
- times_coarsened,
+ time_coarsened,
) = coarsen_batch(
data=target,
start_timestep=start_timestep,
- batch_times=times,
+ batch_time=time,
coarsen_factor=coarsen_factor,
)
# check the coarsened data time dim size
assert target_coarsened[VARNAME].size(dim=1) == len(
expected_coarsened_data
), "target coarsened time dim"
- assert times_coarsened.sizes["time"] == len(
+ assert time_coarsened.sizes["time"] == len(
expected_coarsened_data
- ), "times coarsened time dim"
+ ), "time coarsened time dim"
# check the coarsened data values
n_samples, _, n_lat, n_lon = dim_sizes
- torch.testing.assert_close(
- target_coarsened[VARNAME],
- torch.tensor(expected_coarsened_data, dtype=torch.float64)
- .repeat(n_samples, n_lat, n_lon, 1)
- .movedim(3, 1),
- ), "target coarsened value"
+ torch.testing.assert_close(
+ target_coarsened[VARNAME],
+ torch.tensor(expected_coarsened_data, dtype=torch.float64)
+ .repeat(n_samples, n_lat, n_lon, 1)
+ .movedim(3, 1),
+ msg="target coarsened value",
+ )
# check the coarsened start timestep
assert coarsened_start_timestep == expected_coarsened_start_timestep
# check the coarsened data time coordinate values
- xr.testing.assert_allclose(
- times_coarsened, expected_coarsened_times
- ), "times initial condition value"
+ xr.testing.assert_allclose(time_coarsened, expected_coarsened_time)
diff --git a/fme/fme/ace/inference/data_writer/time_coarsen.py b/fme/fme/ace/inference/data_writer/time_coarsen.py
index 21d294a..a677b74 100644
--- a/fme/fme/ace/inference/data_writer/time_coarsen.py
+++ b/fme/fme/ace/inference/data_writer/time_coarsen.py
@@ -14,7 +14,7 @@ def append_batch(
target: Dict[str, torch.Tensor],
prediction: Dict[str, torch.Tensor],
start_timestep: int,
- batch_times: xr.DataArray,
+ batch_time: xr.DataArray,
):
pass
@@ -27,7 +27,7 @@ def append_batch(
self,
data: Dict[str, torch.Tensor],
start_timestep: int,
- batch_times: xr.DataArray,
+ batch_time: xr.DataArray,
):
pass
@@ -66,7 +66,7 @@ def build(self, data_writer: _DataWriter) -> "TimeCoarsen":
)
def n_coarsened_timesteps(self, n_timesteps: int) -> int:
- """Assumes initial condition is NOT in n_timesteps"""
+ """Assumes initial condition is NOT in n_timesteps."""
return (n_timesteps) // self.coarsen_factor
@@ -86,13 +86,13 @@ def append_batch(
target: Dict[str, torch.Tensor],
prediction: Dict[str, torch.Tensor],
start_timestep: int,
- batch_times: xr.DataArray,
+ batch_time: xr.DataArray,
):
(target_coarsened, start_timestep, batch_times_coarsened) = coarsen_batch(
- target, start_timestep, batch_times, self._coarsen_factor
+ target, start_timestep, batch_time, self._coarsen_factor
)
(prediction_coarsened, _, _) = coarsen_batch(
- prediction, start_timestep, batch_times, self._coarsen_factor
+ prediction, start_timestep, batch_time, self._coarsen_factor
)
self._data_writer.append_batch(
target_coarsened,
@@ -120,10 +120,10 @@ def append_batch(
self,
data: Dict[str, torch.Tensor],
start_timestep: int,
- batch_times: xr.DataArray,
+ batch_time: xr.DataArray,
):
(data_coarsened, start_timestep, batch_times_coarsened) = coarsen_batch(
- data, start_timestep, batch_times, self._coarsen_factor
+ data, start_timestep, batch_time, self._coarsen_factor
)
self._data_writer.append_batch(
data_coarsened,
@@ -138,13 +138,13 @@ def flush(self):
def coarsen_batch(
data: Dict[str, torch.Tensor],
start_timestep: int,
- batch_times: xr.DataArray,
+ batch_time: xr.DataArray,
coarsen_factor: int,
) -> Tuple[Dict[str, torch.Tensor], int, xr.DataArray]:
data_coarsened = _coarsen_tensor_dict(data, coarsen_factor)
start_timestep = start_timestep // coarsen_factor
- batch_times_coarsened = batch_times.coarsen({TIME_DIM_NAME: coarsen_factor}).mean()
- return data_coarsened, start_timestep, batch_times_coarsened
+ batch_time_coarsened = batch_time.coarsen({TIME_DIM_NAME: coarsen_factor}).mean()
+ return data_coarsened, start_timestep, batch_time_coarsened
def _coarsen_tensor_dict(
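A minimal sketch of the window averaging performed by coarsen_batch above (standalone, and assuming an xarray version that supports reductions over cftime values, which the tests above already rely on): a 6-hourly, 4-step time axis coarsened by a factor of 2 reduces to the two window midpoints, and the window start index is integer-divided by the factor.

    import xarray as xr

    # 6-hourly Julian-calendar times, mirroring _get_time_array in the test above
    time = xr.DataArray(
        list(xr.cftime_range("2020-01-01", periods=4, freq="6h", calendar="julian")),
        dims=["time"],
    )
    # non-overlapping mean over windows of 2 steps, as in
    # batch_time.coarsen({TIME_DIM_NAME: coarsen_factor}).mean()
    print(time.coarsen({"time": 2}).mean().values)  # 03:00 and 15:00 on 2020-01-01
    print(3 // 2)  # a start_timestep of 3 maps to a coarsened start of 1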
diff --git a/fme/fme/ace/inference/data_writer/utils.py b/fme/fme/ace/inference/data_writer/utils.py
index 6fd7e4d..52bfb7d 100644
--- a/fme/fme/ace/inference/data_writer/utils.py
+++ b/fme/fme/ace/inference/data_writer/utils.py
@@ -1,3 +1,4 @@
+from dataclasses import dataclass
from typing import Iterable, Optional, Set, TypeVar
T = TypeVar("T")
@@ -17,3 +18,21 @@ def get_all_names(
return variables
else:
return variables.intersection(set(allowlist))
+
+
+@dataclass
+class DimInfo:
+ name: str
+ index: int
+
+
+DIM_INFO_LATLON = [
+ DimInfo(name="lat", index=-2),
+ DimInfo(name="lon", index=-1),
+]
+
+DIM_INFO_HEALPIX = [
+ DimInfo(name="face", index=-3),
+ DimInfo(name="height", index=-2),
+ DimInfo(name="width", index=-1),
+]
diff --git a/fme/fme/ace/inference/data_writer/video.py b/fme/fme/ace/inference/data_writer/video.py
index fd95d01..e55ad80 100644
--- a/fme/fme/ace/inference/data_writer/video.py
+++ b/fme/fme/ace/inference/data_writer/video.py
@@ -5,8 +5,8 @@
import torch
import xarray as xr
-from fme.core.aggregator.inference.video import VideoAggregator
-from fme.core.data_loading.data_typing import VariableMetadata
+from fme.ace.aggregator.inference.video import VideoAggregator
+from fme.core.dataset.data_typing import VariableMetadata
class PairedVideoDataWriter:
@@ -18,25 +18,27 @@ def __init__(
self,
path: str,
n_timesteps: int,
- metadata: Mapping[str, VariableMetadata],
+ variable_metadata: Mapping[str, VariableMetadata],
coords: Mapping[str, np.ndarray],
):
"""
Args:
- filename: Path to write netCDF file(s).
+ path: Directory within which to write the file.
n_samples: Number of samples to write to the file.
n_timesteps: Number of timesteps to write to the file.
- metadata: Metadata for each variable to be written to the file.
+ variable_metadata: Metadata for each variable to be written to the file.
coords: Coordinate data to be written to the file.
"""
self.path = path
self._metrics_filename = str(
Path(path) / "reduced_autoregressive_predictions.nc"
)
- self.metadata = metadata
+ self.variable_metadata = variable_metadata
self.coords = coords
self._video = VideoAggregator(
- n_timesteps=n_timesteps, enable_extended_videos=True, metadata=metadata
+ n_timesteps=n_timesteps,
+ enable_extended_videos=True,
+ variable_metadata=variable_metadata,
)
def append_batch(
@@ -44,7 +46,7 @@ def append_batch(
target: Dict[str, torch.Tensor],
prediction: Dict[str, torch.Tensor],
start_timestep: int,
- batch_times: xr.DataArray,
+ batch_time: xr.DataArray,
):
"""
Append a batch of data to the file.
@@ -53,10 +55,9 @@ def append_batch(
target: Target data.
prediction: Prediction data.
start_timestep: Timestep at which to start writing.
- batch_times: Time coordinates for each sample in the batch. Unused.
+ batch_time: Time coordinate for each sample in the batch. Unused.
"""
self._video.record_batch(
- loss=np.nan,
target_data=target,
gen_data=prediction,
i_time_start=start_timestep,
diff --git a/fme/fme/ace/inference/derived_variables.py b/fme/fme/ace/inference/derived_variables.py
index c5a673d..429bcde 100644
--- a/fme/fme/ace/inference/derived_variables.py
+++ b/fme/fme/ace/inference/derived_variables.py
@@ -1,86 +1,76 @@
-import dataclasses
import datetime
import logging
-from typing import Callable, Dict, List, MutableMapping, Optional
+from typing import Callable, Dict, MutableMapping, Optional
import torch
from fme.core import metrics
from fme.core.climate_data import ClimateData
-from fme.core.data_loading.data_typing import SigmaCoordinates
+from fme.core.coordinates import HybridSigmaPressureCoordinate
from fme.core.device import get_device
-from fme.core.stepper import SteppedData
+DerivedVariableFunc = Callable[
+ [ClimateData, HybridSigmaPressureCoordinate, datetime.timedelta], torch.Tensor
+]
-@dataclasses.dataclass
-class DerivedVariableRegistryEntry:
- func: Callable[[ClimateData, SigmaCoordinates, datetime.timedelta], torch.Tensor]
- required_inputs: Optional[List[str]] = None
+_DERIVED_VARIABLE_REGISTRY: MutableMapping[str, DerivedVariableFunc] = {}
-_DERIVED_VARIABLE_REGISTRY: MutableMapping[str, DerivedVariableRegistryEntry] = {}
+def register(func: DerivedVariableFunc):
+ label = func.__name__
+ if label in _DERIVED_VARIABLE_REGISTRY:
+ raise ValueError(f"Function {label} has already been added to registry.")
+ _DERIVED_VARIABLE_REGISTRY[label] = func
+ return func
-def register(
- required_inputs: Optional[List[str]] = None,
-):
- """Decorator for registering a function that computes a derived variable."""
-
- def decorator(
- func: Callable[
- [ClimateData, SigmaCoordinates, datetime.timedelta], torch.Tensor
- ]
- ):
- label = func.__name__
- if label in _DERIVED_VARIABLE_REGISTRY:
- raise ValueError(f"Function {label} has already been added to registry.")
- _DERIVED_VARIABLE_REGISTRY[label] = DerivedVariableRegistryEntry(
- func=func, required_inputs=required_inputs
- )
- return func
-
- return decorator
-
-@register()
+@register
def surface_pressure_due_to_dry_air(
data: ClimateData,
- sigma_coordinates: SigmaCoordinates,
+ vertical_coordinate: HybridSigmaPressureCoordinate,
timestep: datetime.timedelta,
) -> torch.Tensor:
return metrics.surface_pressure_due_to_dry_air(
data.specific_total_water,
data.surface_pressure,
- sigma_coordinates.ak,
- sigma_coordinates.bk,
+ vertical_coordinate,
)
-@register()
+@register
+def surface_pressure_due_to_dry_air_absolute_tendency(
+ data: ClimateData,
+ vertical_coordinate: HybridSigmaPressureCoordinate,
+ timestep: datetime.timedelta,
+) -> torch.Tensor:
+ ps_dry = surface_pressure_due_to_dry_air(data, vertical_coordinate, timestep)
+ abs_ps_dry_tendency = torch.zeros_like(ps_dry)
+ abs_ps_dry_tendency[:, 1:] = torch.diff(ps_dry, n=1, dim=1).abs()
+ return abs_ps_dry_tendency
+
+
+@register
def total_water_path(
data: ClimateData,
- sigma_coordinates: SigmaCoordinates,
+ vertical_coordinate: HybridSigmaPressureCoordinate,
timestep: datetime.timedelta,
) -> torch.Tensor:
- return metrics.vertical_integral(
+ return vertical_coordinate.vertical_integral(
data.specific_total_water,
data.surface_pressure,
- sigma_coordinates.ak,
- sigma_coordinates.bk,
)
-@register()
+@register
def total_water_path_budget_residual(
data: ClimateData,
- sigma_coordinates: SigmaCoordinates,
+ vertical_coordinate: HybridSigmaPressureCoordinate,
timestep: datetime.timedelta,
):
- total_water_path = metrics.vertical_integral(
+ total_water_path = vertical_coordinate.vertical_integral(
data.specific_total_water,
data.surface_pressure,
- sigma_coordinates.ak,
- sigma_coordinates.bk,
)
twp_total_tendency = (total_water_path[:, 1:] - total_water_path[:, :-1]) / (
timestep.total_seconds()
@@ -95,23 +85,23 @@ def total_water_path_budget_residual(
return twp_budget_residual
-@register(
- required_inputs=[
- "DSWRFtoa",
- ]
-)
+@register
def net_energy_flux_toa_into_atmosphere(
data: ClimateData,
- sigma_coordinates: SigmaCoordinates,
+ vertical_coordinate: HybridSigmaPressureCoordinate,
timestep: datetime.timedelta,
):
- return data.data["DSWRFtoa"] - data.data["USWRFtoa"] - data.data["ULWRFtoa"]
+ return (
+ data.toa_down_sw_radiative_flux
+ - data.toa_up_sw_radiative_flux
+ - data.toa_up_lw_radiative_flux
+ )
-@register()
+@register
def net_energy_flux_sfc_into_atmosphere(
data: ClimateData,
- sigma_coordinates: SigmaCoordinates,
+ vertical_coordinate: HybridSigmaPressureCoordinate,
timestep: datetime.timedelta,
):
# property is defined as positive into surface, but want to compare to
@@ -119,38 +109,36 @@ def net_energy_flux_sfc_into_atmosphere(
return -data.net_surface_energy_flux_without_frozen_precip
-@register(required_inputs=["DSWRFtoa"])
+@register
def net_energy_flux_into_atmospheric_column(
data: ClimateData,
- sigma_coordinates: SigmaCoordinates,
+ vertical_coordinate: HybridSigmaPressureCoordinate,
timestep: datetime.timedelta,
):
return net_energy_flux_sfc_into_atmosphere(
- data, sigma_coordinates, timestep
- ) + net_energy_flux_toa_into_atmosphere(data, sigma_coordinates, timestep)
+ data, vertical_coordinate, timestep
+ ) + net_energy_flux_toa_into_atmosphere(data, vertical_coordinate, timestep)
-@register(required_inputs=["HGTsfc"])
+@register
def column_moist_static_energy(
data: ClimateData,
- sigma_coordinates: SigmaCoordinates,
+ vertical_coordinate: HybridSigmaPressureCoordinate,
timestep: datetime.timedelta,
):
- return metrics.vertical_integral(
- data.moist_static_energy(sigma_coordinates),
+ return vertical_coordinate.vertical_integral(
+ data.moist_static_energy(vertical_coordinate),
data.surface_pressure,
- sigma_coordinates.ak,
- sigma_coordinates.bk,
)
-@register(required_inputs=["HGTsfc"])
+@register
def column_moist_static_energy_tendency(
data: ClimateData,
- sigma_coordinates: SigmaCoordinates,
+ vertical_coordinate: HybridSigmaPressureCoordinate,
timestep: datetime.timedelta,
):
- mse = column_moist_static_energy(data, sigma_coordinates, timestep)
+ mse = column_moist_static_energy(data, vertical_coordinate, timestep)
diff = torch.diff(mse, n=1, dim=1)
# Only the very first timestep in series is filled with nan; subsequent batches
# drop the first step as it's the initial condition.
@@ -164,32 +152,31 @@ def column_moist_static_energy_tendency(
def _compute_derived_variable(
data: Dict[str, torch.Tensor],
- sigma_coordinates: SigmaCoordinates,
+ vertical_coordinate: HybridSigmaPressureCoordinate,
timestep: datetime.timedelta,
label: str,
- derived_variable: DerivedVariableRegistryEntry,
+ derived_variable_func: DerivedVariableFunc,
forcing_data: Optional[Dict[str, torch.Tensor]] = None,
) -> Dict[str, torch.Tensor]:
"""Computes a derived variable and adds it to the given data.
- If the required input data is not available,
- no change will be made to the data.
+ The derived variable name must not already exist in the data.
+
+ If any required input data are not available,
+ the derived variable will not be computed.
Args:
data: dictionary of data to add the derived variable to.
- sigma_coordinates: the vertical coordinate.
+ vertical_coordinate: the vertical coordinate.
timestep: Timestep of the model.
label: the name of the derived variable.
- derived_variable: class indicating required names and function to compute.
+ derived_variable_func: derived variable function to compute.
forcing_data: optional dictionary of forcing data needed for some derived
variables. If necessary forcing inputs are missing, the derived
variable will not be computed.
Returns:
- A new SteppedData instance with the derived variable added.
-
- Note:
- Derived variables are only computed for the denormalized data in stepped.
+ A new data dictionary with the derived variable added.
"""
if label in data:
raise ValueError(
@@ -197,15 +184,15 @@ def _compute_derived_variable(
"to overwrite existing variables with derived variables."
)
new_data = data.copy()
+ if forcing_data is not None:
+ for key, value in forcing_data.items():
+ if key not in data:
+ data[key] = value
+
climate_data = ClimateData(data)
try:
- if forcing_data and derived_variable.required_inputs:
- for v in derived_variable.required_inputs:
- if v not in forcing_data:
- raise KeyError(v)
- climate_data.data.update({v: forcing_data[v]})
- output = derived_variable.func(climate_data, sigma_coordinates, timestep)
+ output = derived_variable_func(climate_data, vertical_coordinate, timestep)
except KeyError as key_error:
logging.debug(f"Could not compute {label} because {key_error} is missing")
else: # if no exception was raised
@@ -215,43 +202,18 @@ def _compute_derived_variable(
def compute_derived_quantities(
data: Dict[str, torch.Tensor],
- sigma_coordinates: SigmaCoordinates,
+ vertical_coordinate: HybridSigmaPressureCoordinate,
timestep: datetime.timedelta,
- registry: MutableMapping[
- str, DerivedVariableRegistryEntry
- ] = _DERIVED_VARIABLE_REGISTRY,
forcing_data: Optional[Dict[str, torch.Tensor]] = None,
) -> Dict[str, torch.Tensor]:
"""Computes all derived quantities from the given data."""
- for label, derived_variable in registry.items():
+ for label, func in _DERIVED_VARIABLE_REGISTRY.items():
data = _compute_derived_variable(
data,
- sigma_coordinates,
+ vertical_coordinate,
timestep,
label,
- derived_variable,
+ func,
forcing_data=forcing_data,
)
return data
-
-
-def compute_stepped_derived_quantities(
- stepped: SteppedData,
- sigma_coordinates: SigmaCoordinates,
- timestep: datetime.timedelta,
- registry: MutableMapping[
- str, DerivedVariableRegistryEntry
- ] = _DERIVED_VARIABLE_REGISTRY,
- forcing_data: Optional[Dict[str, torch.Tensor]] = None,
-) -> SteppedData:
- stepped.gen_data = compute_derived_quantities(
- stepped.gen_data, sigma_coordinates, timestep, registry, forcing_data
- )
- stepped.target_data = compute_derived_quantities(
- stepped.target_data,
- sigma_coordinates,
- timestep,
- registry,
- forcing_data,
- )
- return stepped
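The refactor above replaces DerivedVariableRegistryEntry with a plain registry of functions keyed by __name__, and missing inputs are handled by catching KeyError rather than by declaring required_inputs. A self-contained sketch of that pattern (the names below are illustrative only, not the fme API):

    from typing import Callable, Dict, MutableMapping

    _REGISTRY: MutableMapping[str, Callable[[Dict[str, float]], float]] = {}

    def register(func: Callable[[Dict[str, float]], float]):
        # refuse duplicate registrations, as register() above does
        if func.__name__ in _REGISTRY:
            raise ValueError(f"{func.__name__} has already been registered")
        _REGISTRY[func.__name__] = func
        return func

    @register
    def net_flux(data: Dict[str, float]) -> float:
        return data["down"] - data["up"]

    def compute_all(data: Dict[str, float]) -> Dict[str, float]:
        out = dict(data)
        for name, func in _REGISTRY.items():
            try:
                out[name] = func(data)
            except KeyError:
                pass  # a required input is missing; skip this derived variable
        return out

    print(compute_all({"down": 340.0, "up": 100.0}))  # adds "net_flux": 240.0
    print(compute_all({"down": 340.0}))  # "up" missing, so "net_flux" is skipped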
diff --git a/fme/fme/ace/inference/evaluator.py b/fme/fme/ace/inference/evaluator.py
index 1ce48ce..f3abbbb 100755
--- a/fme/fme/ace/inference/evaluator.py
+++ b/fme/fme/ace/inference/evaluator.py
@@ -2,9 +2,7 @@
import dataclasses
import logging
import os
-import time
-from pathlib import Path
-from typing import Optional, Sequence
+from typing import Callable, Optional
import dacite
import torch
@@ -12,44 +10,53 @@
import fme
import fme.core.logging_utils as logging_utils
+from fme.ace.aggregator.inference import InferenceEvaluatorAggregatorConfig
+from fme.ace.data_loading.batch_data import BatchData, InferenceGriddedData
+from fme.ace.data_loading.getters import get_inference_data
+from fme.ace.data_loading.inference import InferenceDataLoaderConfig
from fme.ace.inference.data_writer import DataWriterConfig, PairedDataWriter
from fme.ace.inference.data_writer.time_coarsen import TimeCoarsenConfig
-from fme.ace.inference.loop import run_dataset_comparison, run_inference_evaluator
-from fme.core import SingleModuleStepper
-from fme.core.aggregator.inference import InferenceEvaluatorAggregatorConfig
-from fme.core.data_loading.data_typing import GriddedData, SigmaCoordinates
-from fme.core.data_loading.getters import get_inference_data
-from fme.core.data_loading.inference import InferenceDataLoaderConfig
+from fme.ace.inference.loop import (
+ DeriverABC,
+ run_dataset_comparison,
+ write_reduced_metrics,
+)
+from fme.ace.stepper import SingleModuleStepper, SingleModuleStepperConfig
from fme.core.dicts import to_flat_dict
+from fme.core.generics.inference import get_record_to_wandb, run_inference
from fme.core.logging_utils import LoggingConfig
from fme.core.ocean import OceanConfig
-from fme.core.stepper import SingleModuleStepperConfig
-from fme.core.wandb import WandB
+from fme.core.timing import GlobalTimer
+from fme.core.typing_ import TensorDict, TensorMapping
-def load_stepper_config(checkpoint_file: str) -> SingleModuleStepperConfig:
+def load_stepper_config(
+ checkpoint_file: str, ocean_config: Optional[OceanConfig]
+) -> SingleModuleStepperConfig:
checkpoint = torch.load(checkpoint_file, map_location=fme.get_device())
- return SingleModuleStepperConfig.from_state(checkpoint["stepper"]["config"])
+ config = SingleModuleStepperConfig.from_state(checkpoint["stepper"]["config"])
+ if ocean_config is not None:
+ logging.info(
+ "Overriding training ocean configuration with the inference ocean config."
+ )
+ config.ocean = ocean_config
+ return config
def load_stepper(
checkpoint_file: str,
- area: torch.Tensor,
- sigma_coordinates: SigmaCoordinates,
ocean_config: Optional[OceanConfig] = None,
) -> SingleModuleStepper:
checkpoint = torch.load(checkpoint_file, map_location=fme.get_device())
- stepper = SingleModuleStepper.from_state(
- checkpoint["stepper"], area=area, sigma_coordinates=sigma_coordinates
- )
+ stepper = SingleModuleStepper.from_state(checkpoint["stepper"])
if ocean_config is not None:
logging.info(
"Overriding training ocean configuration with the inference ocean config."
)
new_ocean = ocean_config.build(
- stepper.in_packer.names, stepper.out_packer.names, stepper.timestep
+ stepper.in_names, stepper.out_names, stepper.timestep
)
- stepper.ocean = new_ocean
+ stepper.replace_ocean(new_ocean)
return stepper
@@ -76,7 +83,7 @@ class InferenceEvaluatorConfig:
"""
Configuration for running inference including comparison to reference data.
- Attributes:
+ Parameters:
experiment_dir: Directory to save results to.
n_forward_steps: Number of steps to run the model forward for.
checkpoint_path: Path to stepper checkpoint to load.
@@ -129,41 +136,22 @@ def configure_wandb(self, env_vars: Optional[dict] = None, **kwargs):
def clean_wandb(self):
self.logging.clean_wandb(self.experiment_dir)
- def configure_gcs(self):
- self.logging.configure_gcs()
-
- def load_stepper(
- self, area: torch.Tensor, sigma_coordinates: SigmaCoordinates
- ) -> SingleModuleStepper:
- """
- Args:
- area: A tensor of shape (n_lat, n_lon) containing the area of
- each grid cell.
- sigma_coordinates: The sigma coordinates of the model.
- """
+ def load_stepper(self) -> SingleModuleStepper:
logging.info(f"Loading trained model checkpoint from {self.checkpoint_path}")
- stepper = load_stepper(
- self.checkpoint_path,
- area=area,
- sigma_coordinates=sigma_coordinates,
- ocean_config=self.ocean,
- )
+ stepper = load_stepper(self.checkpoint_path, ocean_config=self.ocean)
return stepper
def load_stepper_config(self) -> SingleModuleStepperConfig:
logging.info(f"Loading trained model checkpoint from {self.checkpoint_path}")
- return load_stepper_config(self.checkpoint_path)
+ return load_stepper_config(self.checkpoint_path, ocean_config=self.ocean)
- def get_data_writer(
- self, data: GriddedData, prognostic_names: Sequence[str]
- ) -> PairedDataWriter:
+ def get_data_writer(self, data: InferenceGriddedData) -> PairedDataWriter:
return self.data_writer.build_paired(
experiment_dir=self.experiment_dir,
- n_samples=self.loader.n_samples,
+ n_initial_conditions=self.loader.n_initial_conditions,
n_timesteps=self.n_forward_steps,
timestep=data.timestep,
- prognostic_names=prognostic_names,
- metadata=data.metadata,
+ variable_metadata=data.variable_metadata,
coords=data.coords,
)
@@ -180,10 +168,45 @@ def main(yaml_config: str):
os.makedirs(config.experiment_dir, exist_ok=True)
with open(os.path.join(config.experiment_dir, "config.yaml"), "w") as f:
yaml.dump(data, f, default_flow_style=False, sort_keys=False)
- return run_evaluator_from_config(config)
+ with GlobalTimer():
+ return run_evaluator_from_config(config)
+
+
+class _Deriver(DeriverABC):
+ """
+ DeriverABC implementation for dataset comparison.
+ """
+
+ def __init__(
+ self,
+ n_ic_timesteps: int,
+ derive_func: Callable[[TensorMapping, TensorMapping], TensorDict],
+ ):
+ self._n_ic_timesteps = n_ic_timesteps
+ self._derive_func = derive_func
+
+ @property
+ def n_ic_timesteps(self) -> int:
+ return self._n_ic_timesteps
+
+ def get_forward_data(
+ self, data: BatchData, compute_derived_variables: bool = False
+ ) -> BatchData:
+ if compute_derived_variables:
+ timer = GlobalTimer.get_instance()
+ with timer.context("compute_derived_variables"):
+ data = data.compute_derived_variables(
+ derive_func=self._derive_func,
+ forcing_data=data,
+ )
+ return data.remove_initial_condition(self._n_ic_timesteps)
def run_evaluator_from_config(config: InferenceEvaluatorConfig):
+ timer = GlobalTimer.get_instance()
+ timer.start_outer("inference")
+ timer.start("initialization")
+
if not os.path.isdir(config.experiment_dir):
os.makedirs(config.experiment_dir, exist_ok=True)
config.configure_logging(log_filename="inference_out.log")
@@ -196,22 +219,22 @@ def run_evaluator_from_config(config: InferenceEvaluatorConfig):
logging_utils.log_versions()
logging.info(f"Current device is {fme.get_device()}")
- start_time = time.time()
stepper_config = config.load_stepper_config()
logging.info("Loading inference data")
- data_requirements = stepper_config.get_data_requirements(
- n_forward_steps=config.n_forward_steps
+ window_requirements = stepper_config.get_evaluation_window_data_requirements(
+ n_forward_steps=config.forward_steps_in_memory
+ )
+ initial_condition_requirements = (
+ stepper_config.get_prognostic_state_data_requirements()
)
data = get_inference_data(
- config.loader,
- config.forward_steps_in_memory,
- data_requirements,
+ config=config.loader,
+ total_forward_steps=config.n_forward_steps,
+ window_requirements=window_requirements,
+ initial_condition=initial_condition_requirements,
)
- stepper = config.load_stepper(
- data.area_weights.to(fme.get_device()),
- sigma_coordinates=data.sigma_coordinates.to(fme.get_device()),
- )
+ stepper = config.load_stepper()
if stepper.timestep != data.timestep:
raise ValueError(
f"Timestep of the loaded stepper, {stepper.timestep}, does not "
@@ -220,81 +243,86 @@ def run_evaluator_from_config(config: InferenceEvaluatorConfig):
aggregator_config: InferenceEvaluatorAggregatorConfig = config.aggregator
for batch in data.loader:
- initial_times = batch.times.isel(time=0)
+ initial_time = batch.time.isel(time=0)
break
aggregator = aggregator_config.build(
- area_weights=data.area_weights.to(fme.get_device()),
- sigma_coordinates=data.sigma_coordinates,
+ vertical_coordinate=data.vertical_coordinate,
+ horizontal_coordinates=data.horizontal_coordinates,
timestep=data.timestep,
record_step_20=config.n_forward_steps >= 20,
- n_timesteps=config.n_forward_steps + 1,
- metadata=data.metadata,
- data_grid=data.grid,
- initial_times=initial_times,
+ n_timesteps=config.n_forward_steps + stepper_config.n_ic_timesteps,
+ variable_metadata=data.variable_metadata,
+ initial_time=initial_time,
+ channel_mean_names=stepper.out_names,
+ normalize=stepper.normalizer.normalize,
)
- writer = config.get_data_writer(data, stepper.prognostic_names)
+ writer = config.get_data_writer(data)
+ timer.stop()
logging.info("Starting inference")
+ record_logs = get_record_to_wandb(label="inference")
if config.prediction_loader is not None:
prediction_data = get_inference_data(
config.prediction_loader,
- config.forward_steps_in_memory,
- data_requirements,
+ total_forward_steps=config.n_forward_steps,
+ window_requirements=window_requirements,
+ initial_condition=initial_condition_requirements,
)
-
- timers = run_dataset_comparison(
+ deriver = _Deriver(
+ n_ic_timesteps=stepper_config.n_ic_timesteps,
+ derive_func=stepper.derive_func,
+ )
+ run_dataset_comparison(
aggregator=aggregator,
- normalizer=stepper.normalizer,
prediction_data=prediction_data,
target_data=data,
+ deriver=deriver,
writer=writer,
+ record_logs=record_logs,
)
else:
- timers = run_inference_evaluator(
+ run_inference(
+ predict=stepper.predict_paired,
+ data=data,
aggregator=aggregator,
writer=writer,
- stepper=stepper,
- data=data,
+ record_logs=record_logs,
)
- final_flush_start_time = time.time()
+ timer.start("final_writer_flush")
logging.info("Starting final flush of data writer")
writer.flush()
logging.info("Writing reduced metrics to disk in netcdf format.")
- for name, ds in aggregator.get_datasets(
- ("time_mean", "zonal_mean", "histogram")
- ).items():
- coords = {k: v for k, v in data.coords.items() if k in ds.dims}
- ds = ds.assign_coords(coords)
- ds.to_netcdf(Path(config.experiment_dir) / f"{name}_diagnostics.nc")
-
- final_flush_duration = time.time() - final_flush_start_time
- logging.info(f"Final writer flush duration: {final_flush_duration:.2f} seconds")
- timers["final_writer_flush"] = final_flush_duration
-
- duration = time.time() - start_time
- total_steps = config.n_forward_steps * config.loader.n_samples
- total_steps_per_second = total_steps / duration
- logging.info(f"Inference duration: {duration:.2f} seconds")
- logging.info(f"Total steps per second: {total_steps_per_second:.2f} steps/second")
-
- step_logs = aggregator.get_inference_logs(label="inference")
- wandb = WandB.get_instance()
- if wandb.enabled:
- logging.info("Starting logging of metrics to wandb")
- duration_logs = {
- "duration_seconds": duration,
- "total_steps_per_second": total_steps_per_second,
- }
- wandb.log({**timers, **duration_logs}, step=0)
- for i, log in enumerate(step_logs):
- wandb.log(log, step=i, sleep=0.01)
+ write_reduced_metrics(
+ aggregator,
+ data.coords,
+ config.experiment_dir,
+ excluded=[
+ "video",
+ ],
+ )
+ timer.stop()
+
+ timer.stop_outer("inference")
+ total_steps = config.n_forward_steps * config.loader.n_initial_conditions
+ inference_duration = timer.get_duration("inference")
+ wandb_logging_duration = timer.get_duration("wandb_logging")
+ total_steps_per_second = total_steps / (inference_duration - wandb_logging_duration)
+ timer.log_durations()
+ logging.info(
+ "Total steps per second (ignoring wandb logging): "
+ f"{total_steps_per_second:.2f} steps/second"
+ )
+ summary_logs = {
+ "total_steps_per_second": total_steps_per_second,
+ **timer.get_durations(),
+ **aggregator.get_summary_logs(),
+ }
+ record_logs([summary_logs])
config.clean_wandb()
- return step_logs
-
if __name__ == "__main__":
parser = argparse.ArgumentParser()
diff --git a/fme/fme/ace/inference/inference.py b/fme/fme/ace/inference/inference.py
index 88a9173..6796d45 100644
--- a/fme/fme/ace/inference/inference.py
+++ b/fme/fme/ace/inference/inference.py
@@ -1,10 +1,8 @@
-import argparse
+import copy
import dataclasses
import logging
import os
-import time
-from pathlib import Path
-from typing import Literal, Optional, Sequence, Tuple, Union
+from typing import Literal, Optional, Sequence, Union
import dacite
import torch
@@ -13,24 +11,27 @@
import fme
import fme.core.logging_utils as logging_utils
-from fme.ace.inference.data_writer import DataWriter, DataWriterConfig
-from fme.ace.inference.loop import run_inference
-from fme.core import SingleModuleStepper
-from fme.core.aggregator.inference import InferenceAggregatorConfig
-from fme.core.data_loading.data_typing import GriddedData, SigmaCoordinates
-from fme.core.data_loading.getters import get_forcing_data
-from fme.core.data_loading.inference import (
+from fme.ace.aggregator.inference import InferenceAggregatorConfig
+from fme.ace.data_loading.batch_data import (
+ BatchData,
+ InferenceGriddedData,
+ PrognosticState,
+)
+from fme.ace.data_loading.getters import get_forcing_data
+from fme.ace.data_loading.inference import (
ExplicitIndices,
ForcingDataLoaderConfig,
InferenceInitialConditionIndices,
TimestampList,
)
+from fme.ace.inference.data_writer import DataWriter, DataWriterConfig
+from fme.ace.inference.loop import write_reduced_metrics
+from fme.ace.stepper import SingleModuleStepper, SingleModuleStepperConfig
from fme.core.dicts import to_flat_dict
+from fme.core.generics.inference import get_record_to_wandb, run_inference
from fme.core.logging_utils import LoggingConfig
from fme.core.ocean import OceanConfig
-from fme.core.stepper import SingleModuleStepperConfig
-from fme.core.typing_ import TensorMapping
-from fme.core.wandb import WandB
+from fme.core.timing import GlobalTimer
from .evaluator import load_stepper, load_stepper_config, validate_time_coarsen_config
@@ -44,12 +45,12 @@ class InitialConditionConfig:
.. note::
The data specified under path should contain a time dimension of at least
- length 1. If multiple times are present in the dataset specified by `path`,
+ length 1. If multiple times are present in the dataset specified by ``path``,
the inference will start an ensemble simulation using each IC along a
leading sample dimension. Specific times can be selected from the dataset
- by using `start_indices`.
+ by using ``start_indices``.
- Attributes:
+ Parameters:
path: The path to the initial conditions dataset.
engine: The engine used to open the dataset.
start_indices: optional specification of the subset of
@@ -79,7 +80,7 @@ def _subselect_initial_conditions(self, ds: xr.Dataset) -> xr.Dataset:
def get_initial_condition(
ds: xr.Dataset, prognostic_names: Sequence[str]
-) -> Tuple[TensorMapping, xr.DataArray]:
+) -> PrognosticState:
"""Given a dataset, extract a mapping of variables to tensors.
and the time coordinate corresponding to the initial conditions.
@@ -90,7 +91,7 @@ def get_initial_condition(
prognostic_names: Names of prognostic variables to extract from the dataset.
Returns:
- A mapping of variable names to tensors and the time coordinate.
+ The initial condition and the time coordinate.
"""
initial_condition = {}
for name in prognostic_names:
@@ -100,16 +101,26 @@ def get_initial_condition(
f"(n_samples, n_lat, n_lon). Got shape {ds[name].shape}."
)
n_samples = ds[name].shape[0]
- initial_condition[name] = torch.tensor(ds[name].values).to(fme.get_device())
+ initial_condition[name] = torch.tensor(ds[name].values).unsqueeze(dim=1)
if "time" not in ds:
raise ValueError("Initial condition dataset must have a 'time' variable.")
- initial_times = ds.time
- if len(initial_times) != n_samples:
+ initial_times = xr.DataArray(
+ data=ds.time.values[:, None],
+ dims=["sample", "time"],
+ )
+ if initial_times.shape[0] != n_samples:
raise ValueError(
"Length of 'time' variable must match first dimension of variables "
- f"in initial condition dataset. Got {len(initial_times)} and {n_samples}."
+ f"in initial condition dataset. Got {initial_times.shape[0]} "
+ f"and {n_samples}."
)
- return initial_condition, initial_times
+
+ batch_data = BatchData.new_on_cpu(
+ data=initial_condition,
+ time=initial_times,
+ horizontal_dims=["lat", "lon"],
+ )
+ return batch_data.get_start(prognostic_names, n_ic_timesteps=1)
@dataclasses.dataclass
@@ -117,7 +128,7 @@ class InferenceConfig:
"""
Configuration for running inference.
- Attributes:
+ Parameters:
experiment_dir: Directory to save results to.
n_forward_steps: Number of steps to run the model forward for.
checkpoint_path: Path to stepper checkpoint to load.
@@ -167,46 +178,28 @@ def configure_wandb(self, env_vars: Optional[dict] = None, **kwargs):
def clean_wandb(self):
self.logging.clean_wandb(self.experiment_dir)
- def configure_gcs(self):
- self.logging.configure_gcs()
-
- def load_stepper(
- self, area: torch.Tensor, sigma_coordinates: SigmaCoordinates
- ) -> SingleModuleStepper:
- """
- Args:
- area: A tensor of shape (n_lat, n_lon) containing the area of
- each grid cell.
- sigma_coordinates: The sigma coordinates of the model.
- """
+ def load_stepper(self) -> SingleModuleStepper:
logging.info(f"Loading trained model checkpoint from {self.checkpoint_path}")
- stepper = load_stepper(
- self.checkpoint_path,
- area=area,
- sigma_coordinates=sigma_coordinates,
- ocean_config=self.ocean,
- )
+ stepper = load_stepper(self.checkpoint_path, ocean_config=self.ocean)
return stepper
def load_stepper_config(self) -> SingleModuleStepperConfig:
logging.info(f"Loading trained model checkpoint from {self.checkpoint_path}")
- return load_stepper_config(self.checkpoint_path)
+ return load_stepper_config(self.checkpoint_path, ocean_config=self.ocean)
- def get_data_writer(
- self, data: GriddedData, prognostic_names: Sequence[str]
- ) -> DataWriter:
+ def get_data_writer(self, data: InferenceGriddedData) -> DataWriter:
return self.data_writer.build(
experiment_dir=self.experiment_dir,
- n_samples=data.loader.dataset.n_samples,
+ # each batch contains all samples, for different times
+ n_initial_conditions=data.n_initial_conditions,
n_timesteps=self.n_forward_steps,
timestep=data.timestep,
- prognostic_names=prognostic_names,
- metadata=data.metadata,
+ variable_metadata=data.variable_metadata,
coords=data.coords,
)
-def main(yaml_config: str):
+def main(yaml_config: str, segments: Optional[int] = None):
with open(yaml_config, "r") as f:
data = yaml.safe_load(f)
config = dacite.from_dict(
@@ -218,10 +211,19 @@ def main(yaml_config: str):
os.makedirs(config.experiment_dir, exist_ok=True)
with open(os.path.join(config.experiment_dir, "config.yaml"), "w") as f:
yaml.dump(data, f, default_flow_style=False, sort_keys=False)
- run_inference_from_config(config)
+ if segments is None:
+ with GlobalTimer():
+ return run_inference_from_config(config)
+ else:
+ config.configure_logging(log_filename="inference_out.log")
+ run_segmented_inference(config, segments)
def run_inference_from_config(config: InferenceConfig):
+ timer = GlobalTimer.get_instance()
+ timer.start_outer("inference")
+ timer.start("initialization")
+
if not os.path.isdir(config.experiment_dir):
os.makedirs(config.experiment_dir, exist_ok=True)
config.configure_logging(log_filename="inference_out.log")
@@ -234,86 +236,109 @@ def run_inference_from_config(config: InferenceConfig):
logging_utils.log_versions()
logging.info(f"Current device is {fme.get_device()}")
- start_time = time.time()
stepper_config = config.load_stepper_config()
- data_requirements = stepper_config.get_forcing_data_requirements(
- n_forward_steps=config.n_forward_steps
+ data_requirements = stepper_config.get_forcing_window_data_requirements(
+ n_forward_steps=config.forward_steps_in_memory
)
logging.info("Loading initial condition data")
- initial_condition, initial_times = get_initial_condition(
+ initial_condition = get_initial_condition(
config.initial_condition.get_dataset(), stepper_config.prognostic_names
)
+ stepper = config.load_stepper()
logging.info("Initializing forcing data loaded")
data = get_forcing_data(
- config.forcing_loader,
- config.forward_steps_in_memory,
- data_requirements,
- initial_times,
- )
-
- stepper = config.load_stepper(
- data.area_weights.to(fme.get_device()),
- sigma_coordinates=data.sigma_coordinates.to(fme.get_device()),
+ config=config.forcing_loader,
+ total_forward_steps=config.n_forward_steps,
+ window_requirements=data_requirements,
+ initial_condition=initial_condition,
+ surface_temperature_name=stepper.surface_temperature_name,
+ ocean_fraction_name=stepper.ocean_fraction_name,
)
if stepper.timestep != data.timestep:
raise ValueError(
f"Timestep of the loaded stepper, {stepper.timestep}, does not "
f"match that of the forcing data, {data.timestep}."
)
-
aggregator = config.aggregator.build(
- area_weights=data.area_weights.to(fme.get_device()),
- sigma_coordinates=data.sigma_coordinates,
- timestep=data.timestep,
- n_timesteps=config.n_forward_steps + 1,
- metadata=data.metadata,
+ gridded_operations=data.gridded_operations,
+ n_timesteps=config.n_forward_steps + stepper.n_ic_timesteps,
+ variable_metadata=data.variable_metadata,
)
- writer = config.get_data_writer(data, stepper.prognostic_names)
+ writer = config.get_data_writer(data)
+ timer.stop()
logging.info("Starting inference")
- timers = run_inference(
- stepper=stepper,
- initial_condition=initial_condition,
- forcing_data=data,
+ record_logs = get_record_to_wandb(label="inference")
+ run_inference(
+ predict=stepper.predict,
+ data=data,
writer=writer,
aggregator=aggregator,
+ record_logs=record_logs,
)
- final_flush_start_time = time.time()
+ timer.start("final_writer_flush")
logging.info("Starting final flush of data writer")
writer.flush()
- for name, ds in aggregator.get_datasets(("time_mean",)).items():
- coords = {k: v for k, v in data.coords.items() if k in ds.dims}
- ds = ds.assign_coords(coords)
- ds.to_netcdf(Path(config.experiment_dir) / f"{name}_diagnostics.nc")
- final_flush_duration = time.time() - final_flush_start_time
- logging.info(f"Final writer flush duration: {final_flush_duration:.2f} seconds")
- timers["final_writer_flush"] = final_flush_duration
-
- duration = time.time() - start_time
- total_steps = config.n_forward_steps * data.loader.dataset.n_samples
- total_steps_per_second = total_steps / duration
- logging.info(f"Inference duration: {duration:.2f} seconds")
- logging.info(f"Total steps per second: {total_steps_per_second:.2f} steps/second")
-
- step_logs = aggregator.get_inference_logs(label="inference")
- wandb = WandB.get_instance()
- if wandb.enabled:
- logging.info("Starting logging of metrics to wandb")
- duration_logs = {
- "duration_seconds": duration,
- "total_steps_per_second": total_steps_per_second,
- }
- wandb.log({**timers, **duration_logs}, step=0)
- for i, log in enumerate(step_logs):
- wandb.log(log, step=i, sleep=0.01)
+ logging.info("Writing reduced metrics to disk in netcdf format.")
+ write_reduced_metrics(aggregator, data.coords, config.experiment_dir)
+ timer.stop()
+
+ timer.stop_outer("inference")
+ total_steps = config.n_forward_steps * data.n_initial_conditions
+ inference_duration = timer.get_duration("inference")
+ wandb_logging_duration = timer.get_duration("wandb_logging")
+ total_steps_per_second = total_steps / (inference_duration - wandb_logging_duration)
+ timer.log_durations()
+ logging.info(
+ "Total steps per second (ignoring wandb logging): "
+ f"{total_steps_per_second:.2f} steps/second"
+ )
+ summary_logs = {
+ "total_steps_per_second": total_steps_per_second,
+ **timer.get_durations(),
+ **aggregator.get_summary_logs(),
+ }
+ record_logs([summary_logs])
config.clean_wandb()
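For a concrete sense of the throughput metric above: 8 forward steps with 2 initial conditions gives 16 total steps, and a 10-second inference duration of which 2 seconds went to wandb logging would be reported as 16 / (10 - 2) = 2.0 steps/second.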
-if __name__ == "__main__":
- parser = argparse.ArgumentParser()
- parser.add_argument("yaml_config", type=str)
- args = parser.parse_args()
- main(yaml_config=args.yaml_config)
+def run_segmented_inference(config: InferenceConfig, segments: int):
+ """Run inference in multiple segments.
+
+ Args:
+ config: inference configuration to be used for each individual segment. The
+ provided initial condition configuration will only be used for the first
+ segment.
+ segments: total number of segments desired. Only missing segments will be run.
+
+ Note:
+ This is useful when running very long simulations or when saving a large
+ amount of output data to disk. The simulation outputs will be split across
+ multiple folders, each corresponding to one of the segments and labeled by
+ the segment number.
+ """
+ logging.info(
+ f"Starting segmented inference with {segments} segments. "
+ f"Saving to {config.experiment_dir}."
+ )
+ config_copy = copy.deepcopy(config)
+ original_wandb_name = os.environ.get("WANDB_NAME")
+ for segment in range(segments):
+ segment_label = f"segment_{segment:04d}"
+ segment_dir = os.path.join(config.experiment_dir, segment_label)
+ restart_path = os.path.join(segment_dir, "restart.nc")
+ if os.path.exists(restart_path):
+ logging.info(f"Skipping segment {segment} because it has already been run.")
+ else:
+ logging.info(f"Running segment {segment}.")
+ config_copy.experiment_dir = segment_dir
+ if original_wandb_name is not None:
+ os.environ["WANDB_NAME"] = f"{original_wandb_name}-{segment_label}"
+ with GlobalTimer():
+ run_inference_from_config(config_copy)
+ config_copy.initial_condition = InitialConditionConfig(
+ path=restart_path, engine="netcdf4"
+ )
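As a rough sketch of how the segmented entrypoint could be driven (loading the config with dacite here is an assumption for illustration; the actual CLI may construct InferenceConfig differently):

import dacite
import yaml

with open("config.yaml") as f:
    config = dacite.from_dict(data_class=InferenceConfig, data=yaml.safe_load(f))

# Four segments are written to <experiment_dir>/segment_0000 ... segment_0003.
# Segments whose restart.nc already exists are skipped, and each later segment's
# initial condition is taken from the previous segment's restart.nc.
run_segmented_inference(config, segments=4)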
diff --git a/fme/fme/ace/inference/loop.py b/fme/fme/ace/inference/loop.py
index 00f8118..87e4c33 100644
--- a/fme/fme/ace/inference/loop.py
+++ b/fme/fme/ace/inference/loop.py
@@ -1,376 +1,113 @@
+import abc
import logging
-import time
-from collections import defaultdict
-from typing import Any, Dict, Mapping, Optional, Union
+from pathlib import Path
+from typing import Callable, Iterable, Mapping, Optional, Union
-import torch
-import xarray as xr
+import numpy as np
-from fme.core import SingleModuleStepper
-from fme.core.aggregator.inference.main import (
+from fme.ace.aggregator.inference.main import (
InferenceAggregator,
InferenceEvaluatorAggregator,
)
-from fme.core.data_loading.data_typing import GriddedData
-from fme.core.data_loading.utils import BatchData
-from fme.core.device import get_device
-from fme.core.normalizer import StandardNormalizer
-from fme.core.optimization import NullOptimization
-from fme.core.stepper import SteppedData
-from fme.core.typing_ import TensorMapping
-
-from .data_writer import DataWriter, NullDataWriter, PairedDataWriter
-from .derived_variables import (
- compute_derived_quantities,
- compute_stepped_derived_quantities,
-)
+from fme.ace.data_loading.batch_data import BatchData, PairedData, PrognosticState
+from fme.ace.inference.data_writer import PairedDataWriter
+from fme.core.generics.aggregator import InferenceAggregatorABC, InferenceLogs
+from fme.core.generics.data import InferenceDataABC
+from fme.core.generics.inference import get_record_to_wandb
+from fme.core.generics.writer import NullDataWriter
+from fme.core.timing import GlobalTimer
-class WindowStitcher:
+class DeriverABC(abc.ABC):
"""
- Handles stitching together the windows of data from the inference loop.
-
- For example, handles passing in windows to data writers which combine
- them together into a continuous series, and handles storing prognostic
- variables from the end of a window to use as the initial condition for
- the next window.
+ Abstract base class for processing data during dataset comparison.
"""
- def __init__(
- self,
- n_forward_steps: int,
- writer: Union[PairedDataWriter, NullDataWriter],
- ):
- self.i_time = 0
- self.n_forward_steps = n_forward_steps
- self.writer = writer
- # tensors have shape [n_sample, n_lat, n_lon] with no time axis
- self._initial_condition: Optional[Mapping[str, torch.Tensor]] = None
-
- def append(
- self,
- data: Dict[str, torch.tensor],
- gen_data: Dict[str, torch.tensor],
- batch_times: xr.DataArray,
- ) -> None:
- """
- Appends a time segment of data to the ensemble batch.
-
- Args:
- data: The reference data for the current time segment, tensors
- should have shape [n_sample, n_time, n_lat, n_lon]
- gen_data: The generated data for the current time segment, tensors
- should have shape [n_sample, n_time, n_lat, n_lon]
- batch_times: Time coordinates for each sample in the batch.
- """
- tensor_shape = next(data.values().__iter__()).shape
- self.writer.append_batch(
- target=data,
- prediction=gen_data,
- start_timestep=self.i_time,
- batch_times=batch_times,
- )
- self.i_time += tensor_shape[1]
- self._initial_condition = {key: value[:, -1] for key, value in data.items()}
- for key, value in gen_data.items():
- self._initial_condition[key] = value[:, -1]
- for key, value in self._initial_condition.items():
- self._initial_condition[key] = value.detach().cpu()
-
- def apply_initial_condition(self, data: Mapping[str, torch.Tensor]):
- """
- Applies the last recorded state of the batch as the initial condition for
- the next segment of the timeseries.
+ @abc.abstractmethod
+ def get_forward_data(
+ self, data: BatchData, compute_derived_variables: bool = False
+ ) -> BatchData: ...
- Args:
- data: The data to apply the initial condition to, tensors should have
- shape [n_sample, n_time, n_lat, n_lon] and the first value along
- the time axis will be replaced with the last value from the
- previous segment.
- """
- if self.i_time > self.n_forward_steps:
- raise ValueError(
- "Cannot apply initial condition after "
- "the last segment has been appended, currently at "
- f"time index {self.i_time} "
- f"with {self.n_forward_steps} max forward steps."
- )
- if self._initial_condition is not None:
- for key, value in data.items():
- value[:, 0] = self._initial_condition[key].to(value.device)
-
- def save_initial_condition(
- self,
- ic_data: Dict[str, torch.Tensor],
- ic_time: xr.DataArray,
- ):
- self.writer.save_initial_condition(ic_data, ic_time)
+ @property
+ @abc.abstractmethod
+ def n_ic_timesteps(self) -> int: ...
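For orientation, a minimal concrete deriver might look like the sketch below; the pass-through behaviour and the single initial-condition timestep are illustrative assumptions, not the deriver used by the evaluator:

class NoOpDeriver(DeriverABC):
    """Hypothetical deriver that returns each window unchanged."""

    def __init__(self, n_ic_timesteps: int = 1):
        self._n_ic_timesteps = n_ic_timesteps

    def get_forward_data(
        self, data: BatchData, compute_derived_variables: bool = False
    ) -> BatchData:
        # a real deriver would add derived variables here when requested
        return data

    @property
    def n_ic_timesteps(self) -> int:
        return self._n_ic_timesteps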
-def _inference_internal_loop(
- stepped: SteppedData,
- i_time: int,
- aggregator: InferenceEvaluatorAggregator,
- stitcher: WindowStitcher,
- batch_times: xr.DataArray,
+def write_reduced_metrics(
+ aggregator: Union[InferenceEvaluatorAggregator, InferenceAggregator],
+ data_coords: Mapping[str, np.ndarray],
+ path: str,
+ excluded: Optional[Iterable[str]] = None,
):
- """Do operations that need to be done on each time step of the inference loop.
-
- This function exists to de-duplicate code between run_inference_evaluator and
- run_dataset_comparison."""
-
- # The first data window includes the IC, while subsequent windows don't.
- # The aggregators use the full first window including IC.
- # The data writers exclude the IC from the first window.
- if i_time == 0:
- i_time_aggregator = i_time
- stepped_no_ic = stepped.remove_initial_condition()
- stitcher.save_initial_condition(
- ic_data={k: v[:, 0] for k, v in stepped.target_data.items()},
- ic_time=batch_times.isel(time=0),
- )
- batch_times_no_ic = batch_times.isel(time=slice(1, None))
- else:
- i_time_aggregator = i_time + 1
- stepped_no_ic = stepped
- batch_times_no_ic = batch_times
-
- # record raw data for the batch, and store the final state
- # for the next segment
- # Do not include the initial condition in the data writers
- stitcher.append(
- stepped_no_ic.target_data, stepped_no_ic.gen_data, batch_times_no_ic
- )
-
- # record metrics, includes the initial condition
- aggregator.record_batch(
- loss=float(stepped.metrics["loss"]),
- time=batch_times,
- target_data=stepped.target_data,
- gen_data=stepped.gen_data,
- target_data_norm=stepped.target_data_norm,
- gen_data_norm=stepped.gen_data_norm,
- i_time_start=i_time_aggregator,
- )
-
-
-def _to_device(
- data: Mapping[str, torch.Tensor], device: torch.device
-) -> Dict[str, Any]:
- return {key: value.to(device) for key, value in data.items()}
-
-
-def run_inference(
- stepper: SingleModuleStepper,
- initial_condition: TensorMapping,
- forcing_data: GriddedData,
- writer: DataWriter,
- aggregator: InferenceAggregator,
-) -> Dict[str, float]:
- """Run extended inference loop given initial condition and forcing data.
+ """
+ Write the reduced metrics to disk. Each sub-aggregator will write a netCDF file
+ if its `get_dataset` method returns a non-empty dataset.
Args:
- stepper: The model to run inference with.
- initial_condition: Mapping of prognostic names to initial condition tensors of
- shape (n_sample, n_lat, n_lon).
- forcing_data: GriddedData object which includes a DataLoader which will provide
- windows of forcing data appropriately aligned with the initial condition.
- writer: Data writer for saving the inference results to disk.
-
- Returns:
- Execution time in seconds for each step of the inference loop.
+ aggregator: The aggregator to write metrics from.
+ data_coords: Coordinates to assign to the datasets.
+ path: Path to write the metrics to.
+ excluded: Names of sub-aggregators to exclude from writing.
"""
- with torch.no_grad():
- timers: Dict[str, float] = defaultdict(float)
- current_time = time.time()
- i_time = 0
- window_forcing: BatchData
- for window_forcing in forcing_data.loader:
- timers["data_loading"] += time.time() - current_time
- current_time = time.time()
- forward_steps_in_memory = list(window_forcing.data.values())[0].size(1) - 1
- logging.info(
- f"Inference: starting window spanning {i_time}"
- f" to {i_time + forward_steps_in_memory} steps."
- )
- window_forcing_data = _to_device(window_forcing.data, get_device())
- prediction = stepper.predict(
- initial_condition, window_forcing_data, forward_steps_in_memory
- )
- timers["run_on_batch"] += time.time() - current_time
-
- # Replicates the timestep range of the stepped.target_data
- # used in the evaluator computation of derived quantities, which drops
- # the initial condition.
- forcing_data_at_prediction_steps = {
- k: window_forcing_data[k][:, 1:] for k in window_forcing_data
- }
- prediction = compute_derived_quantities(
- prediction,
- forcing_data.sigma_coordinates,
- forcing_data.timestep,
- forcing_data=forcing_data_at_prediction_steps,
- )
-
- forward_times = window_forcing.times.isel(time=slice(1, None))
- writer.append_batch(prediction, i_time, forward_times)
- aggregator.record_batch(
- time=forward_times, data=prediction, i_time_start=i_time + 1
- )
- timers["writer_and_aggregator"] += time.time() - current_time
- current_time = time.time()
- initial_condition = {
- k: prediction[k][:, -1] for k in stepper.prognostic_names
- }
- i_time += forward_steps_in_memory
-
- for name, duration in timers.items():
- logging.info(f"{name} duration: {duration:.2f}s")
- return timers
-
-
-def run_inference_evaluator(
- aggregator: InferenceEvaluatorAggregator,
- stepper: SingleModuleStepper,
- data: GriddedData,
- writer: Optional[Union[PairedDataWriter, NullDataWriter]] = None,
-) -> Dict[str, float]:
- if writer is None:
- writer = NullDataWriter()
- n_forward_steps = data.loader.dataset.n_forward_steps
- stitcher = WindowStitcher(n_forward_steps, writer)
-
- with torch.no_grad():
- # We have data batches with long windows, where all data for a
- # given batch does not fit into memory at once, so we window it in time
- # and run the model on each window in turn.
- #
- # We process each time window and keep track of the
- # final state. We then use this as the initial condition
- # for the next time window.
-
- timers: Dict[str, float] = defaultdict(float)
- current_time = time.time()
- i_time = 0
- for i, window_batch_data in enumerate(data.loader):
- timers["data_loading"] += time.time() - current_time
- current_time = time.time()
- forward_steps_in_memory = (
- list(window_batch_data.data.values())[0].size(1) - 1
- )
- logging.info(
- f"Inference: starting window spanning {i_time}"
- f" to {i_time + forward_steps_in_memory} steps, "
- f"out of total {n_forward_steps}."
- )
- device = get_device()
- window_data = _to_device(window_batch_data.data, device)
-
- stitcher.apply_initial_condition(window_data)
-
- stepped = stepper.run_on_batch(
- window_data,
- NullOptimization(),
- n_forward_steps=forward_steps_in_memory,
- )
-
- # Prepend initial (pre-first-timestep) output for the first window
- if i == 0:
- (
- initial_condition,
- normed_initial_condition,
- ) = stepper.get_initial_condition(window_data)
- stepped = stepped.prepend_initial_condition(
- initial_condition, normed_initial_condition
- )
- batch_times = window_batch_data.times
- else:
- batch_times = window_batch_data.times.isel(time=slice(1, None))
- stepped = compute_stepped_derived_quantities(
- stepped,
- data.sigma_coordinates,
- data.timestep,
- # forcing inputs are in target data but not gen_data
- forcing_data=stepped.target_data,
- )
- timers["run_on_batch"] += time.time() - current_time
- current_time = time.time()
- _inference_internal_loop(stepped, i_time, aggregator, stitcher, batch_times)
- timers["writer_and_aggregator"] += time.time() - current_time
- current_time = time.time()
- i_time += forward_steps_in_memory
-
- for name, duration in timers.items():
- logging.info(f"{name} duration: {duration:.2f}s")
- return timers
+ for name, ds in aggregator.get_datasets(excluded_aggregators=excluded).items():
+ if len(ds) > 0:
+ coords = {k: v for k, v in data_coords.items() if k in ds.dims}
+ ds = ds.assign_coords(coords)
+ ds.to_netcdf(Path(path) / f"{name}_diagnostics.nc")
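A hedged usage sketch mirroring the call in the inference entrypoint, but skipping one sub-aggregator; the "histogram" name is an assumption for illustration:

write_reduced_metrics(
    aggregator,
    data_coords=data.coords,
    path=config.experiment_dir,
    excluded=["histogram"],
)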
def run_dataset_comparison(
- aggregator: InferenceEvaluatorAggregator,
- normalizer: StandardNormalizer,
- prediction_data: GriddedData,
- target_data: GriddedData,
+ aggregator: InferenceAggregatorABC[PairedData, PairedData],
+ prediction_data: InferenceDataABC[PrognosticState, BatchData],
+ target_data: InferenceDataABC[PrognosticState, BatchData],
+ deriver: DeriverABC,
writer: Optional[Union[PairedDataWriter, NullDataWriter]] = None,
-) -> Dict[str, float]:
+ record_logs: Optional[Callable[[InferenceLogs], None]] = None,
+):
+ if record_logs is None:
+ record_logs = get_record_to_wandb(label="inference")
if writer is None:
writer = NullDataWriter()
- n_forward_steps = target_data.loader.dataset.n_forward_steps
- stitcher = WindowStitcher(n_forward_steps, writer)
- device = get_device()
- # We have data batches with long windows, where all data for a
- # given batch does not fit into memory at once, so we window it in time
- # and run the model on each window in turn.
- #
- # We process each time window and keep track of the
- # final state. We then use this as the initial condition
- # for the next time window.
- timers: Dict[str, float] = defaultdict(float)
- current_time = time.time()
+ timer = GlobalTimer.get_instance()
+ timer.start("data_loading")
i_time = 0
+ n_windows = min(len(prediction_data.loader), len(target_data.loader))
for i, (pred, target) in enumerate(zip(prediction_data.loader, target_data.loader)):
- timers["data_loading"] += time.time() - current_time
- current_time = time.time()
+ timer.stop()
+ if i_time == 0:
+ with timer.context("aggregator"):
+ logs = aggregator.record_initial_condition(
+ initial_condition=PairedData.from_batch_data(
+ prediction=prediction_data.initial_condition.as_batch_data(),
+ target=target_data.initial_condition.as_batch_data(),
+ ),
+ )
+ with timer.context("wandb_logging"):
+ record_logs(logs)
+
forward_steps_in_memory = list(pred.data.values())[0].size(1) - 1
logging.info(
- f"Inference: starting window spanning {i_time}"
- f" to {i_time + forward_steps_in_memory} steps,"
- f" out of total {n_forward_steps}."
- )
- pred_window_data = _to_device(pred.data, device)
- target_window_data = _to_device(target.data, device)
- stepped = SteppedData(
- {"loss": torch.tensor(float("nan"))},
- pred_window_data,
- target_window_data,
- normalizer.normalize(pred_window_data),
- normalizer.normalize(target_window_data),
- )
- stepped = compute_stepped_derived_quantities(
- stepped, target_data.sigma_coordinates, target_data.timestep
+ f"Inference: Processing window {i + 1} of {n_windows}"
+ f" spanning {i_time} to {i_time + forward_steps_in_memory} steps."
)
+ pred = deriver.get_forward_data(pred, compute_derived_variables=True)
+ target = deriver.get_forward_data(target, compute_derived_variables=True)
+ paired_data = PairedData.from_batch_data(prediction=pred, target=target)
+
+ with timer.context("data_writer"):
+ writer.append_batch(
+ batch=paired_data,
+ )
+ with timer.context("aggregator"):
+ logs = aggregator.record_batch(
+ data=paired_data,
+ )
- # Windows here all include an initial condition at start.
- # Remove IC and time coord for windows >0 to be consistent with
- # run_on_batch outputs before passing to the shared _inference_internal_loop.
- if i > 0:
- stepped = stepped.remove_initial_condition()
- target_times = target.times.isel(time=slice(1, None))
- else:
- target_times = target.times
+ with timer.context("wandb_logging"):
+ record_logs(logs)
- timers["run_on_batch"] += time.time() - current_time
- current_time = time.time()
- _inference_internal_loop(
- stepped,
- i_time,
- aggregator,
- stitcher,
- target_times,
- )
- timers["writer_and_aggregator"] += time.time() - current_time
- current_time = time.time()
+ timer.start("data_loading")
i_time += forward_steps_in_memory
- for name, duration in timers.items():
- logging.info(f"{name} duration: {duration:.2f}s")
- return timers
+
+ timer.stop()
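A minimal sketch of driving the comparison loop, assuming the prediction/target data objects, aggregator, and deriver have already been constructed; passing writer=None falls back to a NullDataWriter, and logs are recorded to wandb by default:

with GlobalTimer():
    run_dataset_comparison(
        aggregator=aggregator,
        prediction_data=prediction_data,
        target_data=target_data,
        deriver=deriver,
        writer=None,
    )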
diff --git a/fme/fme/ace/inference/stepper_test_data b/fme/fme/ace/inference/stepper_test_data
index b9fc36c..4c3a594 100644
Binary files a/fme/fme/ace/inference/stepper_test_data and b/fme/fme/ace/inference/stepper_test_data differ
diff --git a/fme/fme/ace/inference/test_derived_variables.py b/fme/fme/ace/inference/test_derived_variables.py
index f06b3b0..4c93012 100644
--- a/fme/fme/ace/inference/test_derived_variables.py
+++ b/fme/fme/ace/inference/test_derived_variables.py
@@ -1,70 +1,114 @@
import datetime
+import numpy as np
import pytest
import torch
+import xarray as xr
-from fme.core.data_loading.data_typing import SigmaCoordinates
-from fme.core.stepper import SteppedData
+from fme.ace.stepper import TrainOutput
+from fme.core.climate_data import ClimateData
+from fme.core.coordinates import HybridSigmaPressureCoordinate
+from fme.core.typing_ import TensorDict, TensorMapping
-from .derived_variables import (
- DerivedVariableRegistryEntry,
- _compute_derived_variable,
- compute_stepped_derived_quantities,
-)
+from .derived_variables import _compute_derived_variable, compute_derived_quantities
TIMESTEP = datetime.timedelta(hours=6)
def test_compute_derived_variable():
fake_data = {"PRESsfc": torch.tensor([1.0]), "PRATEsfc": torch.tensor([2.0])}
- sigma_coordinates = None
- derived_variable = DerivedVariableRegistryEntry(
- func=lambda data, *_: data.surface_pressure + data.precipitation_rate
+ vertical_coordinate = HybridSigmaPressureCoordinate(
+ ak=torch.tensor([0.0, 0.0]), bk=torch.tensor([0.0, 1.0])
)
+
+ def _derived_variable_func(data: ClimateData, *_) -> torch.Tensor:
+ return data.surface_pressure + data.precipitation_rate
+
output_data = _compute_derived_variable(
- fake_data, sigma_coordinates, TIMESTEP, "c", derived_variable
+ fake_data, vertical_coordinate, TIMESTEP, "c", _derived_variable_func
)
torch.testing.assert_close(output_data["c"], torch.tensor([3.0]))
def test_compute_derived_variable_raises_value_error_when_overwriting():
fake_data = {"PRESsfc": torch.tensor([1.0]), "PRATEsfc": torch.tensor([2.0])}
- sigma_coordinates = None
- derived_variable = DerivedVariableRegistryEntry(
- func=lambda data, _: data.surface_pressure + data.precipitation_rate
+ vertical_coordinate = HybridSigmaPressureCoordinate(
+ ak=torch.tensor([0.0, 0.0]), bk=torch.tensor([0.0, 1.0])
)
+
+ def add_surface_pressure_and_precipitation(data: ClimateData, *_) -> torch.Tensor:
+ return data.surface_pressure + data.precipitation_rate
+
+ derived_variable_func = add_surface_pressure_and_precipitation
with pytest.raises(ValueError):
_compute_derived_variable(
- fake_data, sigma_coordinates, TIMESTEP, "PRATEsfc", derived_variable
+ fake_data, vertical_coordinate, TIMESTEP, "PRATEsfc", derived_variable_func
)
-def test_compute_derived_quantities():
+@pytest.mark.parametrize("dataset", ["fv3", "e3sm"])
+def test_compute_derived_quantities(dataset: str):
torch.manual_seed(0)
- fake_data = {
- "PRESsfc": 10.0 + torch.rand(2, 3, 4, 8),
- "specific_total_water_0": torch.rand(2, 3, 4, 8),
- "specific_total_water_1": torch.rand(2, 3, 4, 8),
- "PRATEsfc": torch.rand(2, 3, 4, 8),
- "LHTFLsfc": torch.rand(2, 3, 4, 8),
- "tendency_of_total_water_path_due_to_advection": torch.rand(2, 3, 4, 8),
- }
- data = SteppedData(
- metrics={"loss": torch.tensor(0.0)},
- gen_data=fake_data,
- target_data=fake_data,
- gen_data_norm=fake_data,
- target_data_norm=fake_data,
- )
- sigma_coordinates = SigmaCoordinates(
+
+ if dataset == "fv3":
+ fake_data = {
+ "PRESsfc": 10.0 + torch.rand(2, 3, 4, 8),
+ "specific_total_water_0": torch.rand(2, 3, 4, 8),
+ "specific_total_water_1": torch.rand(2, 3, 4, 8),
+ "PRATEsfc": torch.rand(2, 3, 4, 8),
+ "LHTFLsfc": torch.rand(2, 3, 4, 8),
+ "tendency_of_total_water_path_due_to_advection": torch.rand(2, 3, 4, 8),
+ "DSWRFtoa": torch.rand(2, 3, 4, 8),
+ "USWRFtoa": torch.rand(2, 3, 4, 8),
+ "ULWRFtoa": torch.rand(2, 3, 4, 8),
+ }
+ gen_data = fake_data.copy()
+ del gen_data["DSWRFtoa"]
+
+ if dataset == "e3sm":
+ fake_data = {
+ "PS": 10.0 + torch.rand(2, 3, 4, 8),
+ "specific_total_water_0": torch.rand(2, 3, 4, 8),
+ "specific_total_water_1": torch.rand(2, 3, 4, 8),
+ "surface_precipitation_rate": torch.rand(2, 3, 4, 8),
+ "LHFLX": torch.rand(2, 3, 4, 8),
+ "tendency_of_total_water_path_due_to_advection": torch.rand(2, 3, 4, 8),
+ "SOLIN": torch.rand(2, 3, 4, 8),
+ "top_of_atmos_upward_shortwave_flux": torch.rand(2, 3, 4, 8),
+ "FLUT": torch.rand(2, 3, 4, 8),
+ }
+ gen_data = fake_data.copy()
+ del gen_data["SOLIN"]
+
+ vertical_coordinate = HybridSigmaPressureCoordinate(
ak=torch.tensor([0.0, 0.5, 0.0]),
bk=torch.tensor([0.0, 0.5, 1.0]),
)
- out_data = compute_stepped_derived_quantities(data, sigma_coordinates, TIMESTEP)
+
+ def derive_func(data: TensorMapping, forcing_data: TensorMapping) -> TensorDict:
+ updated = compute_derived_quantities(
+ dict(data),
+ vertical_coordinate=vertical_coordinate,
+ timestep=TIMESTEP,
+ forcing_data=dict(forcing_data),
+ )
+ return updated
+
+ data = TrainOutput(
+ metrics={"loss": torch.tensor(0.0)},
+ gen_data=gen_data,
+ target_data=fake_data,
+ time=xr.DataArray(np.zeros((2, 3)), dims=["sample", "time"]),
+ normalize=lambda x: x,
+ derive_func=derive_func,
+ )
+ out_data = data.compute_derived_variables()
for name in (
"total_water_path_budget_residual",
"total_water_path",
"surface_pressure_due_to_dry_air",
+ "surface_pressure_due_to_dry_air_absolute_tendency",
+ "net_energy_flux_toa_into_atmosphere",
):
assert name in out_data.gen_data
assert name in out_data.target_data
diff --git a/fme/fme/ace/inference/test_evaluator.py b/fme/fme/ace/inference/test_evaluator.py
index a2227ae..99cad2a 100644
--- a/fme/fme/ace/inference/test_evaluator.py
+++ b/fme/fme/ace/inference/test_evaluator.py
@@ -1,8 +1,8 @@
-import contextlib
import dataclasses
import datetime
+import os
import pathlib
-from typing import List, Tuple
+from typing import List
import dacite
import numpy as np
@@ -11,25 +11,29 @@
import xarray as xr
import yaml
+from fme.ace.aggregator.inference import InferenceEvaluatorAggregatorConfig
+from fme.ace.data_loading.inference import (
+ InferenceDataLoaderConfig,
+ InferenceInitialConditionIndices,
+)
from fme.ace.inference.data_writer import DataWriterConfig
from fme.ace.inference.data_writer.time_coarsen import TimeCoarsenConfig
-from fme.ace.inference.derived_variables import compute_stepped_derived_quantities
from fme.ace.inference.evaluator import InferenceEvaluatorConfig, main
from fme.ace.registry import ModuleSelector
+from fme.ace.stepper import SingleModuleStepperConfig, TrainOutput
+from fme.ace.testing import DimSizes, FV3GFSData, MonthlyReferenceData
from fme.core import metrics
-from fme.core.aggregator.inference import InferenceEvaluatorAggregatorConfig, annual
-from fme.core.data_loading.config import XarrayDataConfig
-from fme.core.data_loading.data_typing import SigmaCoordinates
-from fme.core.data_loading.inference import (
- InferenceDataLoaderConfig,
- InferenceInitialConditionIndices,
-)
+from fme.core.coordinates import DimSize, HybridSigmaPressureCoordinate
+from fme.core.dataset.config import XarrayDataConfig
from fme.core.device import get_device
+from fme.core.gridded_ops import LatLonOperations
from fme.core.logging_utils import LoggingConfig
-from fme.core.normalizer import FromStateNormalizer
+from fme.core.normalizer import NormalizationConfig
from fme.core.ocean import Ocean, OceanConfig
-from fme.core.stepper import SingleModuleStepperConfig, SteppedData
-from fme.core.testing import DimSizes, FV3GFSData, MonthlyReferenceData, mock_wandb
+from fme.core.testing import mock_wandb
+from fme.core.typing_ import TensorDict, TensorMapping
+
+from .derived_variables import compute_derived_quantities
DIR = pathlib.Path(__file__).parent
TIMESTEP = datetime.timedelta(hours=6)
@@ -40,41 +44,34 @@ def forward(self, x):
return x + 1
-@contextlib.contextmanager
-def patch_annual_aggregator_min_samples(value):
- original = annual.MIN_SAMPLES
- try:
- annual.MIN_SAMPLES = value
- yield
- finally:
- annual.MIN_SAMPLES = original
-
-
def save_plus_one_stepper(
path: pathlib.Path,
- names: List[str],
+ in_names: List[str],
+ out_names: List[str],
mean: float,
std: float,
- data_shape: Tuple[int, int, int],
+ data_shape: List[int],
timestep: datetime.timedelta = TIMESTEP,
+ nz_interface: int = 7,
):
+ all_names = list(set(in_names).union(out_names))
config = SingleModuleStepperConfig(
builder=ModuleSelector(type="prebuilt", config={"module": PlusOne()}),
- in_names=["var"],
- out_names=["var"],
- normalization=FromStateNormalizer(
- state={
- "means": {name: mean for name in names},
- "stds": {name: std for name in names},
- }
+ in_names=in_names,
+ out_names=out_names,
+ normalization=NormalizationConfig(
+ means={name: mean for name in all_names},
+ stds={name: std for name in all_names},
),
)
area = torch.ones(data_shape[-2:], device=get_device())
- sigma_coordinates = SigmaCoordinates(ak=torch.arange(7), bk=torch.arange(7))
+ vertical_coordinate = HybridSigmaPressureCoordinate(
+ ak=torch.arange(nz_interface), bk=torch.arange(nz_interface)
+ )
stepper = config.get_stepper(
- img_shape=data_shape[-2:],
- area=area,
- sigma_coordinates=sigma_coordinates,
+ img_shape=(data_shape[-2], data_shape[-1]),
+ gridded_operations=LatLonOperations(area),
+ vertical_coordinate=vertical_coordinate,
timestep=timestep,
)
torch.save({"stepper": stepper.get_state()}, path)
@@ -87,12 +84,12 @@ def test_inference_backwards_compatibility(tmp_path: pathlib.Path):
"""
in_names = ["var"]
out_names = ["var"]
- all_names = list(set(in_names).union(out_names))
stepper_path = DIR / "stepper_test_data"
+
+ horizontal = [DimSize("grid_yt", 4), DimSize("grid_xt", 8)]
dim_sizes = DimSizes(
n_time=8,
- n_lat=4,
- n_lon=8,
+ horizontal=horizontal,
nz_interface=2,
)
std = 1.0
@@ -101,17 +98,19 @@ def test_inference_backwards_compatibility(tmp_path: pathlib.Path):
# to re-generate, just delete the data and run the test (it will fail)
save_plus_one_stepper(
stepper_path,
- names=all_names,
+ in_names,
+ out_names,
mean=0.0,
std=std,
- data_shape=dim_sizes.shape_2d,
+ data_shape=dim_sizes.shape_nd,
)
assert False, "stepper_test_data did not exist, it has been created"
use_prediction_data = False
n_forward_steps = 2
inference_helper(
tmp_path,
- all_names,
+ in_names,
+ out_names,
use_prediction_data,
dim_sizes,
n_forward_steps,
@@ -129,12 +128,12 @@ def test_inference_plus_one_model(
):
in_names = ["var"]
out_names = ["var"]
- all_names = list(set(in_names).union(out_names))
stepper_path = tmp_path / "stepper"
+
+ horizontal = [DimSize("grid_yt", 16), DimSize("grid_xt", 32)]
dim_sizes = DimSizes(
n_time=n_forward_steps + 1,
- n_lat=16,
- n_lon=32,
+ horizontal=horizontal,
nz_interface=4,
)
if use_prediction_data:
@@ -144,15 +143,17 @@ def test_inference_plus_one_model(
std = 1.0
save_plus_one_stepper(
stepper_path,
- names=all_names,
+ in_names,
+ out_names,
mean=0.0,
std=std,
- data_shape=dim_sizes.shape_2d,
+ data_shape=dim_sizes.shape_nd,
timestep=datetime.timedelta(days=20),
)
inference_helper(
tmp_path,
- all_names,
+ in_names,
+ out_names,
use_prediction_data,
dim_sizes,
n_forward_steps,
@@ -164,21 +165,26 @@ def test_inference_plus_one_model(
def inference_helper(
tmp_path,
- all_names,
+ in_names,
+ out_names,
use_prediction_data,
dim_sizes: DimSizes,
n_forward_steps,
stepper_path,
timestep: datetime.timedelta,
save_monthly_files: bool = True,
+ derived_names: List[str] = [],
):
time_varying_values = [float(i) for i in range(dim_sizes.n_time)]
+ all_names = list(set(in_names).union(out_names))
+ forcing_names = list(set(in_names).difference(out_names))
data = FV3GFSData(
path=tmp_path,
names=all_names,
dim_sizes=dim_sizes,
time_varying_values=time_varying_values,
timestep_days=timestep.total_seconds() / 86400,
+ save_vertical_coordinate=False,
)
if use_prediction_data:
prediction_data = data.inference_data_loader_config
@@ -192,8 +198,7 @@ def inference_helper(
names=all_names,
dim_sizes=DimSizes(
n_time=48,
- n_lat=dim_sizes.n_lat,
- n_lon=dim_sizes.n_lon,
+ horizontal=dim_sizes.horizontal,
nz_interface=1,
),
n_ensemble=3,
@@ -228,18 +233,39 @@ def inference_helper(
yaml.dump(dataclasses.asdict(config), f)
with mock_wandb() as wandb:
- inference_logs = main(
+ wandb.configure(log_to_wandb=True)
+ main(
yaml_config=str(config_filename),
)
+ wandb_logs = wandb.get_logs()
+
+ all_out_names = out_names + derived_names
+
+ n_ic_timesteps = 1
+ summary_log_step = 1
+ assert len(wandb_logs) == n_ic_timesteps + config.n_forward_steps + summary_log_step
+ for i in range(n_ic_timesteps + config.n_forward_steps):
+ log = wandb_logs[i]
+ for var in all_out_names:
+ if i == 0 and var not in in_names:
+ assert f"inference/mean/weighted_rmse/{var}" not in log
+ else:
+ # if these are off by something like 90% then probably the stepper
+ # is being used instead of the prediction_data
+ assert log[f"inference/mean/weighted_rmse/{var}"] == 0.0
+ assert log[f"inference/mean/weighted_bias/{var}"] == 0.0
+
+ var = list(set(in_names).difference(forcing_names))[0]
+
+ if not use_prediction_data:
+ initial_condition_ds = xr.open_dataset(
+ tmp_path / "initial_condition.nc", decode_timedelta=False
+ )
+ for dim_name in ["lat", "lon"]:
+ assert dim_name in initial_condition_ds.dims
+ assert dim_name in initial_condition_ds.data_vars[var].dims
+ assert dim_name in initial_condition_ds.coords
- # Unlike the data writer outputs, aggregator logs include IC step
- assert len(inference_logs) == config.n_forward_steps + 1
- assert len(wandb.get_logs()) == len(inference_logs)
- for log in inference_logs:
- # if these are off by something like 90% then probably the stepper
- # is being used instead of the prediction_data
- assert log["inference/mean/weighted_rmse/var"] == 0.0
- assert log["inference/mean/weighted_bias/var"] == 0.0
prediction_ds = xr.open_dataset(
tmp_path / "autoregressive_predictions.nc",
decode_timedelta=False,
@@ -248,90 +274,103 @@ def inference_helper(
assert len(prediction_ds["time"]) == config.n_forward_steps
for i in range(config.n_forward_steps - 1):
np.testing.assert_allclose(
- prediction_ds["var"].isel(time=i).values + 1,
- prediction_ds["var"].isel(time=i + 1).values,
+ prediction_ds[var].isel(time=i).values + 1,
+ prediction_ds[var].isel(time=i + 1).values,
)
- assert not np.any(np.isnan(prediction_ds["var"].isel(time=i + 1).values))
+ assert not np.any(np.isnan(prediction_ds[var].isel(time=i + 1).values))
assert "lat" in prediction_ds.coords
assert "lon" in prediction_ds.coords
- restart_ds = xr.open_dataset(
- tmp_path / "restart.nc", decode_timedelta=False, decode_times=False
- )
- np.testing.assert_allclose(
- prediction_ds["var"].isel(time=-1).values,
- restart_ds["var"].values,
- )
+ if use_prediction_data:
+ assert not os.path.exists(tmp_path / "restart.nc")
+ assert not os.path.exists(tmp_path / "initial_condition.nc")
+ else:
+ restart_ds = xr.open_dataset(
+ tmp_path / "restart.nc", decode_timedelta=False, decode_times=False
+ )
+ np.testing.assert_allclose(
+ prediction_ds[var].isel(time=-1).values,
+ restart_ds[var].values,
+ )
- ic_ds = xr.open_dataset(
- tmp_path / "initial_condition.nc", decode_timedelta=False, decode_times=False
- )
- np.testing.assert_allclose(ic_ds["var"].values, 0.0)
+ ic_ds = xr.open_dataset(
+ tmp_path / "initial_condition.nc",
+ decode_timedelta=False,
+ decode_times=False,
+ )
+ np.testing.assert_allclose(ic_ds[var].values, 0.0)
metric_ds = xr.open_dataset(tmp_path / "reduced_autoregressive_predictions.nc")
- assert "var" in metric_ds.data_vars
- assert metric_ds.data_vars["var"].attrs["units"] == "m"
- assert metric_ds.data_vars["var"].attrs["long_name"] == "ensemble mean of var"
- assert "rmse_var" in metric_ds.data_vars
- assert metric_ds.data_vars["rmse_var"].attrs["units"] == "m"
+ assert var in metric_ds.data_vars
+ assert metric_ds.data_vars[var].attrs["units"] == "m"
+ assert metric_ds.data_vars[var].attrs["long_name"] == f"ensemble mean of {var}"
+ assert f"rmse_{var}" in metric_ds.data_vars
+ assert metric_ds.data_vars[f"rmse_{var}"].attrs["units"] == "m"
assert (
- metric_ds.data_vars["rmse_var"].attrs["long_name"]
- == "root mean squared error of var"
- )
- assert "bias_var" in metric_ds.data_vars
- assert metric_ds.data_vars["bias_var"].attrs["units"] == "m"
- assert "min_err_var" in metric_ds.data_vars
- assert metric_ds.data_vars["min_err_var"].attrs["units"] == "m"
- assert "max_err_var" in metric_ds.data_vars
- assert metric_ds.data_vars["max_err_var"].attrs["units"] == "m"
- assert "gen_var_var" in metric_ds.data_vars
- assert metric_ds.data_vars["gen_var_var"].attrs["units"] == ""
+ metric_ds.data_vars[f"rmse_{var}"].attrs["long_name"]
+ == f"root mean squared error of {var}"
+ )
+ assert f"bias_{var}" in metric_ds.data_vars
+ assert metric_ds.data_vars[f"bias_{var}"].attrs["units"] == "m"
+ assert f"min_err_{var}" in metric_ds.data_vars
+ assert metric_ds.data_vars[f"min_err_{var}"].attrs["units"] == "m"
+ assert f"max_err_{var}" in metric_ds.data_vars
+ assert metric_ds.data_vars[f"max_err_{var}"].attrs["units"] == "m"
+ assert f"gen_var_{var}" in metric_ds.data_vars
+ assert metric_ds.data_vars[f"gen_var_{var}"].attrs["units"] == ""
assert (
- metric_ds.data_vars["gen_var_var"].attrs["long_name"]
- == "prediction variance of var as fraction of target variance"
+ metric_ds.data_vars[f"gen_var_{var}"].attrs["long_name"]
+ == f"prediction variance of {var} as fraction of target variance"
)
assert "lat" in metric_ds.coords
assert "lon" in metric_ds.coords
time_mean_diagnostics = xr.open_dataset(tmp_path / "time_mean_diagnostics.nc")
actual_var_names = sorted([str(k) for k in time_mean_diagnostics.keys()])
- assert len(actual_var_names) == 2
- assert "bias_map-var" in actual_var_names
- assert time_mean_diagnostics.data_vars["bias_map-var"].attrs["units"] == "m"
- assert "gen_map-var" in actual_var_names
- assert time_mean_diagnostics.data_vars["gen_map-var"].attrs["units"] == "m"
+ assert len(actual_var_names) == 2 * len(all_out_names)
+ assert f"bias_map-{var}" in actual_var_names
+ assert time_mean_diagnostics.data_vars[f"bias_map-{var}"].attrs["units"] == "m"
+ assert f"gen_map-{var}" in actual_var_names
+ assert time_mean_diagnostics.data_vars[f"gen_map-{var}"].attrs["units"] == "m"
assert len(time_mean_diagnostics.coords) == 2
assert "lat" in time_mean_diagnostics.coords
assert "lon" in time_mean_diagnostics.coords
zonal_mean_diagnostics = xr.open_dataset(tmp_path / "zonal_mean_diagnostics.nc")
actual_var_names = sorted([str(k) for k in zonal_mean_diagnostics.keys()])
- assert len(actual_var_names) == 2
- assert "error-var" in actual_var_names
- assert zonal_mean_diagnostics.data_vars["error-var"].attrs["units"] == "m"
- assert "gen-var" in actual_var_names
- assert zonal_mean_diagnostics.data_vars["gen-var"].attrs["units"] == ""
+ assert len(actual_var_names) == 2 * len(all_out_names)
+ assert f"error-{var}" in actual_var_names
+ assert zonal_mean_diagnostics.data_vars[f"error-{var}"].attrs["units"] == "m"
+ assert f"gen-{var}" in actual_var_names
+ assert zonal_mean_diagnostics.data_vars[f"gen-{var}"].attrs["units"] == ""
assert len(zonal_mean_diagnostics.coords) == 1
assert "lat" in zonal_mean_diagnostics.coords
for source in ["target", "prediction"]:
histograms = xr.open_dataset(tmp_path / f"histograms_{source}.nc")
actual_var_names = sorted([str(k) for k in histograms.keys()])
- assert len(actual_var_names) == 2
- assert "var" in actual_var_names
- assert histograms.data_vars["var"].attrs["units"] == "count"
- assert "var_bin_edges" in actual_var_names
- assert histograms.data_vars["var_bin_edges"].attrs["units"] == "m"
- var_counts_per_timestep = histograms["var"].sum(dim=["bin"])
+ # NOTE: target histograms include forcing variables
+ n_vars = (
+ len(all_out_names)
+ if source == "prediction"
+ else len(all_out_names) + len(forcing_names)
+ )
+ assert len(actual_var_names) == 2 * n_vars
+ assert var in actual_var_names
+ assert histograms.data_vars[var].attrs["units"] == "count"
+ assert f"{var}_bin_edges" in actual_var_names
+ assert histograms.data_vars[f"{var}_bin_edges"].attrs["units"] == "m"
+ var_counts_per_timestep = histograms[var].sum(dim=["bin"])
same_count_each_timestep = np.all(
var_counts_per_timestep.values == var_counts_per_timestep.values[0]
)
assert same_count_each_timestep
if monthly_reference_filename is not None:
- assert "inference/annual/var" in inference_logs[-1]
- assert "inference/annual/r2_gen_var" in inference_logs[-1]
- assert "inference/annual/r2_target_var" in inference_logs[-1]
+ assert f"inference/annual/{var}" in wandb_logs[-1]
+ assert f"inference/annual/r2_gen_{var}" in wandb_logs[-1]
+ assert f"inference/annual/r2_target_{var}" in wandb_logs[-1]
+ assert "inference/total_steps_per_second" in wandb_logs[-1]
@pytest.mark.parametrize(
@@ -345,14 +384,21 @@ def test_inference_writer_boundaries(
out_names = ["var"]
all_names = list(set(in_names).union(out_names))
stepper_path = tmp_path / "stepper"
+
+ horizontal = [DimSize("grid_yt", 4), DimSize("grid_xt", 8)]
+
dim_sizes = DimSizes(
n_time=n_forward_steps + 1,
- n_lat=4,
- n_lon=8,
+ horizontal=horizontal,
nz_interface=4,
)
save_plus_one_stepper(
- stepper_path, names=all_names, mean=0.0, std=1.0, data_shape=dim_sizes.shape_2d
+ stepper_path,
+ in_names,
+ out_names,
+ mean=0.0,
+ std=1.0,
+ data_shape=dim_sizes.shape_nd,
)
data = FV3GFSData(
path=tmp_path,
@@ -376,12 +422,17 @@ def test_inference_writer_boundaries(
with open(config_filename, "w") as f:
yaml.dump(dataclasses.asdict(config), f)
with mock_wandb() as wandb:
- inference_logs = main(
+ wandb.configure(log_to_wandb=True)
+ main(
yaml_config=str(config_filename),
)
- # initial condition + n_forward_steps autoregressive steps
- assert len(inference_logs) == config.n_forward_steps + 1
- assert len(wandb.get_logs()) == len(inference_logs)
+ inference_logs = wandb.get_logs()
+ n_ic_timesteps = 1
+ summary_log_step = 1
+ assert (
+ len(inference_logs)
+ == n_ic_timesteps + config.n_forward_steps + summary_log_step
+ )
prediction_ds = xr.open_dataset(
tmp_path / "autoregressive_predictions.nc", decode_timedelta=False
@@ -401,7 +452,6 @@ def test_inference_writer_boundaries(
tar["lat"].values, num_lon=len(tar["lon"])
)
# check time mean metrics
- assert inference_logs[-1]["inference/mean/forecast_step"] == n_forward_steps
tol = 1e-4 # relative tolerance
assert metrics.root_mean_squared_error(
tar_time_mean, gen_time_mean, area_weights
@@ -416,13 +466,13 @@ def test_inference_writer_boundaries(
prediction_ds = prediction_ds.isel(sample=0)
target_ds = target_ds.isel(sample=0)
- ds = xr.open_dataset(data._data_filename)
+ ds = xr.open_dataset(data.data_filename)
for i in range(0, n_forward_steps):
# metrics logs includes IC while saved data does not
- log = inference_logs[i + 1]
+ log = inference_logs[i + n_ic_timesteps]
# metric steps should match lead times
- assert log["inference/mean/forecast_step"] == i + 1
+ assert log["inference/mean/forecast_step"] == i + n_ic_timesteps
gen_i = torch.from_numpy(gen.isel(time=i).values)
tar_i = torch.from_numpy(tar.isel(time=i).values)
# check that manually computed metrics match logged metrics
@@ -469,14 +519,21 @@ def test_inference_data_time_coarsening(tmp_path: pathlib.Path):
out_names = ["var"]
all_names = list(set(in_names).union(out_names))
stepper_path = tmp_path / "stepper"
+
+ horizontal = [DimSize("grid_yt", 16), DimSize("grid_xt", 32)]
+
dim_sizes = DimSizes(
n_time=9,
- n_lat=16,
- n_lon=32,
+ horizontal=horizontal,
nz_interface=4,
)
save_plus_one_stepper(
- stepper_path, names=all_names, mean=0.0, std=1.0, data_shape=dim_sizes.shape_2d
+ stepper_path,
+ in_names,
+ out_names,
+ mean=0.0,
+ std=1.0,
+ data_shape=dim_sizes.shape_nd,
)
data = FV3GFSData(
path=tmp_path,
@@ -542,27 +599,33 @@ def _make_data():
for var in vars
}
- loss = 42.0
- fake_data = {
- k: _make_data()
- for k in ("gen_data", "target_data", "gen_data_norm", "target_data_norm")
- }
- stepped = SteppedData(
- loss,
- fake_data["gen_data"],
- fake_data["target_data"],
- fake_data["gen_data_norm"],
- fake_data["target_data_norm"],
- )
-
- sigma_coords = SigmaCoordinates(
+ vertical_coordinate = HybridSigmaPressureCoordinate(
ak=torch.linspace(0, 1, nz + 1, device=get_device()),
bk=torch.linspace(0, 1, nz + 1, device=get_device()),
)
- derived_stepped = compute_stepped_derived_quantities(
- stepped, sigma_coords, TIMESTEP
+
+ def derive_func(data: TensorMapping, forcing_data: TensorMapping) -> TensorDict:
+ updated = compute_derived_quantities(
+ dict(data),
+ vertical_coordinate=vertical_coordinate,
+ timestep=TIMESTEP,
+ forcing_data=dict(forcing_data),
+ )
+ return updated
+
+ metrics = {"loss": 42.0}
+ fake_data = {k: _make_data() for k in ("gen_data", "target_data")}
+ stepped = TrainOutput(
+ metrics,
+ fake_data["gen_data"],
+ fake_data["target_data"],
+ time=xr.DataArray(np.zeros((n_sample, n_time)), dims=["sample", "time"]),
+ normalize=lambda x: x,
+ derive_func=derive_func,
)
+ derived_stepped = stepped.compute_derived_variables()
+
dry_air_name = "surface_pressure_due_to_dry_air"
water_path_name = "total_water_path"
existence_check = (
@@ -595,14 +658,22 @@ def test_derived_metrics_run_without_errors(tmp_path: pathlib.Path):
out_names = ["var", "PRESsfc", "specific_total_water_0", "specific_total_water_1"]
all_names = list(set(in_names).union(out_names))
stepper_path = tmp_path / "stepper"
+
+ horizontal = [DimSize("grid_yt", 16), DimSize("grid_xt", 32)]
+
dim_sizes = DimSizes(
n_time=n_forward_steps + 1,
- n_lat=16,
- n_lon=32,
- nz_interface=4,
+ horizontal=horizontal,
+ nz_interface=3,
)
save_plus_one_stepper(
- stepper_path, names=all_names, mean=0.0, std=1.0, data_shape=dim_sizes.shape_2d
+ stepper_path,
+ in_names,
+ out_names,
+ mean=0.0,
+ std=1.0,
+ data_shape=dim_sizes.shape_nd,
+ nz_interface=dim_sizes.nz_interface,
)
time_varying_values = [float(i) for i in range(dim_sizes.n_time)]
data = FV3GFSData(
@@ -611,6 +682,7 @@ def test_derived_metrics_run_without_errors(tmp_path: pathlib.Path):
dim_sizes=dim_sizes,
time_varying_values=time_varying_values,
timestep_days=TIMESTEP.total_seconds() / 86400,
+ num_data_workers=2,
)
config = InferenceEvaluatorConfig(
experiment_dir=str(tmp_path),
@@ -630,10 +702,14 @@ def test_derived_metrics_run_without_errors(tmp_path: pathlib.Path):
with open(config_filename, "w") as f:
yaml.dump(dataclasses.asdict(config), f)
- with mock_wandb() as _:
- _ = main(
- yaml_config=str(config_filename),
- )
+ with mock_wandb() as wandb:
+ wandb.configure(log_to_wandb=True)
+ main(yaml_config=str(config_filename))
+ inference_logs = wandb.get_logs()
+
+ # derived variables should not have normalized metrics reported
+ assert "inference/mean_norm/weighted_rmse/total_water_path" not in inference_logs[0]
+ assert "inference/time_mean_norm/rmse/total_water_path" not in inference_logs[-1]
@pytest.mark.parametrize(
@@ -678,14 +754,20 @@ def test_inference_ocean_override(tmp_path: pathlib.Path):
all_names = list(set(in_names).union(out_names))
stepper_path = tmp_path / "stepper"
n_forward_steps = 8
+
+ horizontal = [DimSize("grid_yt", 4), DimSize("grid_xt", 8)]
dim_sizes = DimSizes(
n_time=n_forward_steps + 1,
- n_lat=4,
- n_lon=8,
+ horizontal=horizontal,
nz_interface=4,
)
save_plus_one_stepper(
- stepper_path, names=all_names, mean=0.0, std=1.0, data_shape=dim_sizes.shape_2d
+ stepper_path,
+ in_names,
+ out_names,
+ mean=0.0,
+ std=1.0,
+ data_shape=dim_sizes.shape_nd,
)
data = FV3GFSData(
path=tmp_path,
@@ -708,10 +790,7 @@ def test_inference_ocean_override(tmp_path: pathlib.Path):
forward_steps_in_memory=4,
ocean=ocean_override,
)
- stepper = config.load_stepper(
- sigma_coordinates=SigmaCoordinates(ak=torch.arange(7), bk=torch.arange(7)),
- area=torch.ones(10),
- )
+ stepper = config.load_stepper()
assert isinstance(stepper.ocean, Ocean)
assert (
stepper.ocean.surface_temperature_name
@@ -719,6 +798,9 @@ def test_inference_ocean_override(tmp_path: pathlib.Path):
)
assert stepper.ocean.ocean_fraction_name == ocean_override.ocean_fraction_name
+ stepper_config = config.load_stepper_config()
+ assert stepper_config.ocean == ocean_override
+
def test_inference_timestep_mismatch_error(tmp_path: pathlib.Path):
"""Test that inference with a model trained with a different timestep than
@@ -726,16 +808,18 @@ def test_inference_timestep_mismatch_error(tmp_path: pathlib.Path):
"""
in_names = ["var"]
out_names = ["var"]
- all_names = list(set(in_names).union(out_names))
stepper_path = tmp_path / "stepper_test_data"
- dim_sizes = DimSizes(n_time=8, n_lat=4, n_lon=8, nz_interface=2)
+
+ horizontal = [DimSize("grid_yt", 4), DimSize("grid_xt", 8)]
+ dim_sizes = DimSizes(n_time=8, horizontal=horizontal, nz_interface=2)
std = 1.0
save_plus_one_stepper(
stepper_path,
- names=all_names,
+ in_names,
+ out_names,
mean=0.0,
std=std,
- data_shape=dim_sizes.shape_2d,
+ data_shape=dim_sizes.shape_nd,
timestep=TIMESTEP,
)
use_prediction_data = False
@@ -743,10 +827,70 @@ def test_inference_timestep_mismatch_error(tmp_path: pathlib.Path):
with pytest.raises(ValueError, match="Timestep of the loaded stepper"):
inference_helper(
tmp_path,
- all_names,
+ in_names,
+ out_names,
use_prediction_data,
dim_sizes,
n_forward_steps,
stepper_path,
timestep=datetime.timedelta(days=20),
)
+
+
+def test_inference_includes_diagnostics(tmp_path: pathlib.Path):
+ """Test that diagnostics are included in evaluator metrics and outputs."""
+ # NOTE: size of in_names and out_names has to be the same here or the
+ # PlusOne outputs won't have the right shape
+ in_names = ["prog", "forcing_var", "DSWRFtoa"]
+ out_names = ["prog", "ULWRFtoa", "USWRFtoa"]
+ stepper_path = tmp_path / "stepper"
+ horizontal = [DimSize("grid_yt", 16), DimSize("grid_xt", 32)]
+ use_prediction_data = False
+ n_forward_steps = 2
+ dim_sizes = DimSizes(
+ n_time=n_forward_steps + 1,
+ horizontal=horizontal,
+ nz_interface=4,
+ )
+ save_plus_one_stepper(
+ stepper_path,
+ in_names,
+ out_names,
+ mean=0.0,
+ std=1.0,
+ data_shape=dim_sizes.shape_nd,
+ timestep=datetime.timedelta(days=20),
+ )
+ inference_helper(
+ tmp_path,
+ in_names,
+ out_names,
+ use_prediction_data,
+ dim_sizes,
+ n_forward_steps,
+ stepper_path,
+ save_monthly_files=False, # requires timestep == 6h
+ timestep=datetime.timedelta(days=20),
+ derived_names=["net_energy_flux_toa_into_atmosphere"],
+ )
+ ds = xr.open_dataset(
+ tmp_path / "autoregressive_predictions.nc",
+ decode_timedelta=False,
+ decode_times=False,
+ )
+ # prognostic in
+ assert "prog" in ds
+ # diags in
+ assert "ULWRFtoa" in ds
+ assert "USWRFtoa" in ds
+ # derived in
+ assert "net_energy_flux_toa_into_atmosphere" in ds
+ # forcings not in
+ assert "DSWRFtoa" not in ds
+ assert "forcing_var" not in ds
+ # assert only prognostic variables are in initial condition and restart files
+ for filename in ["initial_condition.nc", "restart.nc"]:
+ ds = xr.open_dataset(tmp_path / filename)
+ assert "USWRFtoa" not in ds
+ assert "forcing_var" not in ds
+ assert "prog" in ds
diff --git a/fme/fme/ace/inference/test_inference.py b/fme/fme/ace/inference/test_inference.py
index 66c4bc1..cfabd18 100644
--- a/fme/fme/ace/inference/test_inference.py
+++ b/fme/fme/ace/inference/test_inference.py
@@ -3,7 +3,7 @@
import dataclasses
import datetime
import pathlib
-from typing import List, Tuple
+from typing import List
import cftime
import numpy as np
@@ -13,6 +13,13 @@
import yaml
import fme
+from fme.ace.data_loading.batch_data import PrognosticState
+from fme.ace.data_loading.inference import (
+ ExplicitIndices,
+ ForcingDataLoaderConfig,
+ InferenceInitialConditionIndices,
+ TimestampList,
+)
from fme.ace.inference.data_writer import DataWriterConfig
from fme.ace.inference.inference import (
InferenceConfig,
@@ -21,17 +28,17 @@
main,
)
from fme.ace.registry import ModuleSelector
-from fme.core.data_loading.data_typing import SigmaCoordinates
-from fme.core.data_loading.inference import (
- ExplicitIndices,
- ForcingDataLoaderConfig,
- InferenceInitialConditionIndices,
- TimestampList,
+from fme.ace.stepper import SingleModuleStepperConfig
+from fme.ace.testing import DimSizes, FV3GFSData
+from fme.core.coordinates import (
+ DimSize,
+ HybridSigmaPressureCoordinate,
+ LatLonCoordinates,
)
+from fme.core.gridded_ops import LatLonOperations
from fme.core.logging_utils import LoggingConfig
-from fme.core.normalizer import FromStateNormalizer
-from fme.core.stepper import SingleModuleStepperConfig
-from fme.core.testing import DimSizes, FV3GFSData
+from fme.core.normalizer import NormalizationConfig
+from fme.core.testing import mock_wandb
TIMESTEP = datetime.timedelta(hours=6)
@@ -47,7 +54,7 @@ def save_stepper(
out_names: List[str],
mean: float,
std: float,
- data_shape: Tuple[int, int, int],
+ data_shape: List[int],
timestep: datetime.timedelta = TIMESTEP,
):
all_names = list(set(in_names).union(out_names))
@@ -55,19 +62,19 @@ def save_stepper(
builder=ModuleSelector(type="prebuilt", config={"module": PlusOne()}),
in_names=in_names,
out_names=out_names,
- normalization=FromStateNormalizer(
- state={
- "means": {name: mean for name in all_names},
- "stds": {name: std for name in all_names},
- }
+ normalization=NormalizationConfig(
+ means={name: mean for name in all_names},
+ stds={name: std for name in all_names},
),
)
area = torch.ones(data_shape[-2:], device=fme.get_device())
- sigma_coordinates = SigmaCoordinates(ak=torch.arange(7), bk=torch.arange(7))
+ vertical_coordinate = HybridSigmaPressureCoordinate(
+ ak=torch.arange(7), bk=torch.arange(7)
+ )
stepper = config.get_stepper(
- img_shape=data_shape[-2:],
- area=area,
- sigma_coordinates=sigma_coordinates,
+ img_shape=(data_shape[-2], data_shape[-1]),
+ gridded_operations=LatLonOperations(area),
+ vertical_coordinate=vertical_coordinate,
timestep=timestep,
)
torch.save({"stepper": stepper.get_state()}, path)
@@ -75,13 +82,16 @@ def save_stepper(
def test_inference_entrypoint(tmp_path: pathlib.Path):
forward_steps_in_memory = 2
- in_names = ["prog", "forcing_var"]
- out_names = ["prog", "diagnostic_var"]
+ # NOTE: number of inputs and outputs has to be the same for the PlusOne
+ # stepper module to work properly
+ in_names = ["prog", "forcing_var", "DSWRFtoa"]
+ out_names = ["prog", "ULWRFtoa", "USWRFtoa"]
stepper_path = tmp_path / "stepper"
+ horizontal = [DimSize("grid_yt", 16), DimSize("grid_xt", 32)]
+
dim_sizes = DimSizes(
n_time=9,
- n_lat=16,
- n_lon=32,
+ horizontal=horizontal,
nz_interface=4,
)
save_stepper(
@@ -90,23 +100,37 @@ def test_inference_entrypoint(tmp_path: pathlib.Path):
out_names=out_names,
mean=0.0,
std=1.0,
- data_shape=dim_sizes.shape_2d,
+ data_shape=dim_sizes.shape_nd,
)
data = FV3GFSData(
path=tmp_path,
- names=["forcing_var"],
+ names=["forcing_var", "DSWRFtoa"],
dim_sizes=dim_sizes,
timestep_days=0.25,
+ save_vertical_coordinate=False,
)
initial_condition = xr.Dataset(
- {"prog": xr.DataArray(np.random.rand(2, 16, 32), dims=["sample", "lat", "lon"])}
+ {
+ "prog": xr.DataArray(
+ np.random.rand(2, 16, 32), dims=["sample", "lat", "lon"]
+ ),
+ "forcing": xr.DataArray(
+ np.random.rand(2, 16, 32), dims=["sample", "lat", "lon"]
+ ),
+ "DSWRFtoa": xr.DataArray(
+ np.random.rand(2, 16, 32), dims=["sample", "lat", "lon"]
+ ),
+ }
)
initial_condition_path = tmp_path / "init_data" / "ic.nc"
initial_condition_path.parent.mkdir()
initial_condition["time"] = xr.DataArray(
- [cftime.datetime(2000, 1, 1, 6), cftime.datetime(2000, 1, 1, 18)],
- dims=["sample"],
+ [
+ cftime.DatetimeProlepticGregorian(2000, 1, 1, 6),
+ cftime.DatetimeProlepticGregorian(2000, 1, 1, 18),
+ ],
+ dims=["time"],
)
initial_condition.to_netcdf(initial_condition_path, mode="w")
forcing_loader = ForcingDataLoaderConfig(
@@ -122,7 +146,7 @@ def test_inference_entrypoint(tmp_path: pathlib.Path):
logging=LoggingConfig(
log_to_screen=True,
log_to_file=False,
- log_to_wandb=False,
+ log_to_wandb=True,
),
initial_condition=InitialConditionConfig(path=str(initial_condition_path)),
forcing_loader=forcing_loader,
@@ -131,9 +155,39 @@ def test_inference_entrypoint(tmp_path: pathlib.Path):
config_filename = tmp_path / "config.yaml"
with open(config_filename, "w") as f:
yaml.dump(dataclasses.asdict(config), f)
- main(yaml_config=str(config_filename))
+
+ with mock_wandb() as wandb:
+ wandb.configure(log_to_wandb=True)
+ main(yaml_config=str(config_filename))
+ wandb_logs = wandb.get_logs()
+
+ n_ic_timesteps = 1
+ summary_log_step = 1
+ assert len(wandb_logs) == n_ic_timesteps + config.n_forward_steps + summary_log_step
+ for i, log in enumerate(wandb_logs):
+ for metric, val in log.items():
+ # check that time series metrics match
+ if "inference/mean" in metric:
+ if i > 0:
+ assert metric in wandb_logs[i]
+ if np.isnan(val):
+ assert np.isnan(wandb_logs[i][metric])
+ else:
+ assert wandb_logs[i][metric] == val
+ elif not np.isnan(val): # for IC only valid data is reported to wandb
+ assert metric in wandb_logs[i]
+ assert wandb_logs[i][metric] == val
+
ds = xr.open_dataset(tmp_path / "autoregressive_predictions.nc")
+ # prognostic in
assert "prog" in ds
+ # diags in
+ assert "ULWRFtoa" in ds
+ assert "USWRFtoa" in ds
+ # derived in
+ assert "net_energy_flux_toa_into_atmosphere" in ds
+ # forcings not in
+ assert "DSWRFtoa" not in ds
assert "forcing_var" not in ds
assert ds["prog"].sizes == {"time": 4, "sample": 2, "lat": 16, "lon": 32}
np.testing.assert_allclose(
@@ -142,6 +196,26 @@ def test_inference_entrypoint(tmp_path: pathlib.Path):
np.testing.assert_allclose(
ds["prog"].isel(time=1).values, ds["prog"].isel(time=0).values + 1, rtol=1e-6
)
+ saved_data = xr.open_dataset(data.data_filename)
+ ops = LatLonCoordinates(
+ lat=torch.as_tensor(saved_data["grid_yt"].values),
+ lon=torch.as_tensor(saved_data["grid_xt"].values),
+ ).gridded_operations
+ # check that inference logs match raw output
+ for i in range(1, config.n_forward_steps + 1):
+ for log_name in wandb_logs[i]:
+ if "inference/mean/weighted_mean_gen" in log_name:
+ variable_name = log_name.split("/")[-1]
+ # note the raw output does not include the initial condition, hence
+ # the i - 1 below. The inference code uses the area weights from the
+ # data, not from the stepper saved above.
+ raw_variable = ds[variable_name].isel(time=i - 1)
+ raw_global_mean = ops.area_weighted_mean(
+ torch.as_tensor(raw_variable.values)
+ ).mean()
+ np.testing.assert_allclose(
+ raw_global_mean, wandb_logs[i][log_name], rtol=1e-6
+ )
+ assert "inference/total_steps_per_second" in wandb_logs[-1]
def test_get_initial_condition():
@@ -150,13 +224,19 @@ def test_get_initial_condition():
np.random.rand(2, 16, 32), dims=["sample", "lat", "lon"]
)
data = xr.Dataset({"prog": prognostic_da, "time": time_da})
- initial_condition, initial_times = get_initial_condition(data, ["prog"])
+ initial_condition = get_initial_condition(data, ["prog"])
+ assert isinstance(initial_condition, PrognosticState)
+ batch_data = initial_condition.as_batch_data()
+ assert batch_data.time.shape == (2, 1)
+ initial_times = batch_data.time.isel(time=0)
assert initial_times.shape == (2,)
assert initial_times[0] == 0
assert initial_times[1] == 5
- assert initial_condition["prog"].shape == (2, 16, 32)
- np.testing.assert_allclose(initial_condition["prog"].numpy(), data["prog"].values)
- assert initial_condition["prog"].device == fme.get_device()
+ assert batch_data.data["prog"].shape == (2, 1, 16, 32)
+ np.testing.assert_allclose(
+ batch_data.data["prog"].squeeze(dim=1).cpu().numpy(), data["prog"].values
+ )
+ assert batch_data.time.isel(time=0).equals(initial_times)
def test_get_initial_condition_raises_bad_variable_shape():
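For reference, test_inference_entrypoint above compares the wandb-logged inference/mean/weighted_mean_gen metrics against an area-weighted mean of the raw NetCDF output. A minimal standalone sketch of the cos-latitude weighting such a mean is assumed to apply on a regular lat-lon grid (plain torch, no fme imports; cos_lat_weighted_mean is a hypothetical helper, not part of this patch):

import torch


def cos_lat_weighted_mean(field: torch.Tensor, lat_deg: torch.Tensor) -> torch.Tensor:
    """Global mean of a (lat, lon) field weighted by cos(latitude)."""
    weights = torch.cos(torch.deg2rad(lat_deg))  # shape (n_lat,)
    weights = weights[:, None].expand_as(field)  # broadcast across longitude
    return (field * weights).sum() / weights.sum()


lat = torch.linspace(-89.0, 89.0, 16)
field = torch.rand(16, 32)
print(float(cos_lat_weighted_mean(field, lat)))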
diff --git a/fme/fme/ace/inference/test_segmented.py b/fme/fme/ace/inference/test_segmented.py
new file mode 100644
index 0000000..bc552d7
--- /dev/null
+++ b/fme/fme/ace/inference/test_segmented.py
@@ -0,0 +1,244 @@
+"""Tests for segmented inference entrypoint."""
+
+import dataclasses
+import datetime
+import os
+import pathlib
+import tempfile
+import unittest.mock
+from typing import List
+
+import cftime
+import numpy as np
+import pytest
+import torch
+import xarray as xr
+import yaml
+
+import fme
+from fme.ace.data_loading.inference import ForcingDataLoaderConfig, TimestampList
+from fme.ace.inference.data_writer import DataWriterConfig
+from fme.ace.inference.inference import (
+ InitialConditionConfig,
+ main,
+ run_segmented_inference,
+)
+from fme.ace.registry import ModuleSelector
+from fme.ace.stepper import SingleModuleStepperConfig
+from fme.ace.testing import DimSizes, FV3GFSData
+from fme.core.coordinates import DimSize, HybridSigmaPressureCoordinate
+from fme.core.dataset.config import XarrayDataConfig
+from fme.core.gridded_ops import LatLonOperations
+from fme.core.logging_utils import LoggingConfig
+from fme.core.normalizer import NormalizationConfig
+
+TIMESTEP = datetime.timedelta(hours=6)
+
+
+class PlusOne(torch.nn.Module):
+ def forward(self, x):
+ return x + 1
+
+
+def save_stepper(
+ path: pathlib.Path,
+ in_names: List[str],
+ out_names: List[str],
+ mean: float,
+ std: float,
+ data_shape: List[int],
+ timestep: datetime.timedelta = TIMESTEP,
+):
+ all_names = list(set(in_names).union(out_names))
+ config = SingleModuleStepperConfig(
+ builder=ModuleSelector(type="prebuilt", config={"module": PlusOne()}),
+ in_names=in_names,
+ out_names=out_names,
+ normalization=NormalizationConfig(
+ means={name: mean for name in all_names},
+ stds={name: std for name in all_names},
+ ),
+ )
+ area = torch.ones(data_shape[-2:], device=fme.get_device())
+ vertical_coordinate = HybridSigmaPressureCoordinate(
+ ak=torch.arange(7), bk=torch.arange(7)
+ )
+ stepper = config.get_stepper(
+ img_shape=(data_shape[-2], data_shape[-1]),
+ gridded_operations=LatLonOperations(area),
+ vertical_coordinate=vertical_coordinate,
+ timestep=timestep,
+ )
+ torch.save({"stepper": stepper.get_state()}, path)
+
+
+def test_inference_segmented_entrypoint():
+ # we use tempfile here instead of pytest tmp_path fixture, because the latter causes
+ # issues with checking last modified time of files produced by the test.
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ tmp_path = pathlib.Path(tmp_dir)
+ forward_steps_in_memory = 2
+ in_names = ["prog", "forcing_var"]
+ out_names = ["prog", "diagnostic_var"]
+ stepper_path = tmp_path / "stepper"
+ horizontal = [DimSize("grid_yt", 16), DimSize("grid_xt", 32)]
+
+ dim_sizes = DimSizes(
+ n_time=18,
+ horizontal=horizontal,
+ nz_interface=4,
+ )
+ save_stepper(
+ stepper_path,
+ in_names=in_names,
+ out_names=out_names,
+ mean=0.0,
+ std=1.0,
+ data_shape=dim_sizes.shape_nd,
+ )
+ data = FV3GFSData(
+ path=tmp_path,
+ names=["forcing_var"],
+ dim_sizes=dim_sizes,
+ timestep_days=0.25,
+ )
+ initial_condition = xr.Dataset(
+ {
+ "prog": xr.DataArray(
+ np.random.rand(2, 16, 32), dims=["sample", "lat", "lon"]
+ )
+ }
+ )
+
+ initial_condition_path = tmp_path / "init_data" / "ic.nc"
+ initial_condition_path.parent.mkdir()
+ initial_condition["time"] = xr.DataArray(
+ [cftime.datetime(2000, 1, 1, 6), cftime.datetime(2000, 1, 1, 18)],
+ dims=["sample"],
+ )
+ initial_condition.to_netcdf(initial_condition_path, mode="w")
+ forcing_loader = ForcingDataLoaderConfig(
+ dataset=data.inference_data_loader_config.dataset,
+ num_data_workers=0,
+ )
+
+ run_dir = tmp_path / "segmented_run"
+ config = fme.ace.InferenceConfig(
+ experiment_dir=str(run_dir),
+ n_forward_steps=3,
+ forward_steps_in_memory=forward_steps_in_memory,
+ checkpoint_path=str(stepper_path),
+ logging=LoggingConfig(
+ log_to_screen=True, log_to_file=False, log_to_wandb=False
+ ),
+ initial_condition=InitialConditionConfig(
+ path=str(initial_condition_path),
+ start_indices=TimestampList(["2000-01-01T06:00:00"]),
+ ),
+ forcing_loader=forcing_loader,
+ data_writer=DataWriterConfig(save_prediction_files=True),
+ )
+
+ # run one segment of 3 steps
+ config_path = str(tmp_path / "config.yaml")
+ with open(config_path, "w") as f:
+ yaml.dump(dataclasses.asdict(config), f)
+ main(config_path, 1)
+
+ # run another segment of 3 steps, and ensure first segment is not being re-run
+ filename = os.path.join(
+ run_dir, "segment_0000", "autoregressive_predictions.nc"
+ )
+ before_second_segment_mtime = os.path.getmtime(filename)
+ main(config_path, 2)
+ after_second_segment_mtime = os.path.getmtime(filename)
+ assert before_second_segment_mtime == pytest.approx(after_second_segment_mtime)
+
+ # do a non-segmented run of 6 steps
+ config.n_forward_steps = 6
+ config.experiment_dir = str(tmp_path / "non_segmented_run")
+ config_path = str(tmp_path / "config.yaml")
+ with open(config_path, "w") as f:
+ yaml.dump(dataclasses.asdict(config), f)
+ main(config_path)
+
+ # assert each segment generated output of correct duration
+ ds_two_segments_0 = xr.open_dataset(
+ run_dir / "segment_0000" / "autoregressive_predictions.nc"
+ )
+ ds_two_segments_1 = xr.open_dataset(
+ run_dir / "segment_0001" / "autoregressive_predictions.nc"
+ )
+ assert len(ds_two_segments_0.time) == len(ds_two_segments_1.time)
+
+ # Ensure the second half of the 6-step run matches the second segment of the
+ # 3-step run. Before comparing, drop init_time and time coordinates, since
+ # we don't expect these to match.
+ ds_one_segment = xr.open_dataset(
+ tmp_path / "non_segmented_run" / "autoregressive_predictions.nc"
+ )
+ ds_two_segments_1 = ds_two_segments_1.drop_vars(["init_time", "time"])
+ ds_one_segment = ds_one_segment.drop_vars(["init_time", "time"])
+ xr.testing.assert_identical(
+ ds_two_segments_1, ds_one_segment.isel(time=slice(3, None))
+ )
+
+
+def _run_inference_from_config_mock(config: fme.ace.InferenceConfig):
+ if not os.path.exists(config.experiment_dir):
+ os.makedirs(config.experiment_dir)
+ with open(os.path.join(config.experiment_dir, "restart.nc"), "w") as f:
+ f.write("mock restart file")
+ with open(os.path.join(config.experiment_dir, "wandb_name_env_var"), "w") as f:
+ f.write(os.environ.get("WANDB_NAME", ""))
+
+
+def _get_mock_config(experiment_dir: str) -> fme.ace.InferenceConfig:
+ return fme.ace.InferenceConfig(
+ experiment_dir=experiment_dir,
+ n_forward_steps=3,
+ checkpoint_path="mock_checkpoint",
+ logging=LoggingConfig(
+ log_to_screen=True, log_to_file=False, log_to_wandb=False
+ ),
+ initial_condition=InitialConditionConfig(path="mock_ic"),
+ forcing_loader=ForcingDataLoaderConfig(
+ dataset=XarrayDataConfig(data_path="mock_forcing")
+ ),
+ )
+
+
+def test_run_segmented_inference(tmp_path, monkeypatch):
+ WRITTEN_WANDB_NAME_FILENAME = "wandb_name_env_var"
+ mock = unittest.mock.MagicMock(side_effect=_run_inference_from_config_mock)
+ config = _get_mock_config(str(tmp_path))
+
+ with unittest.mock.patch(
+ "fme.ace.inference.inference.run_inference_from_config", new=mock
+ ):
+ # run a single segment
+ monkeypatch.setenv("WANDB_NAME", "run_name")
+ run_segmented_inference(config, 1)
+ segment_dir = os.path.join(config.experiment_dir, "segment_0000")
+ expected_restart_path = os.path.join(segment_dir, "restart.nc")
+ assert os.path.exists(expected_restart_path)
+ assert mock.call_count == 1
+ with open(os.path.join(segment_dir, WRITTEN_WANDB_NAME_FILENAME)) as f:
+ assert f.read() == "run_name-segment_0000"
+
+ # rerun the same segment and ensure run_inference_from_config isn't called again
+ run_segmented_inference(config, 1)
+ assert os.path.exists(expected_restart_path)
+ assert mock.call_count == 1
+
+ # extend to three segments and ensure exactly three run_inference_from_config
+ # calls have been made
+ monkeypatch.setenv("WANDB_NAME", "run_name")
+ run_segmented_inference(config, 3)
+ for i in range(3):
+ segment_dir = os.path.join(config.experiment_dir, f"segment_{i:04d}")
+ expected_restart_path = os.path.join(segment_dir, "restart.nc")
+ assert os.path.exists(expected_restart_path)
+ with open(os.path.join(segment_dir, WRITTEN_WANDB_NAME_FILENAME)) as f:
+ assert f.read() == f"run_name-segment_{i:04d}"
+ assert mock.call_count == 3
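The tests above pin down the resume contract of run_segmented_inference: each segment writes to its own segment_NNNN directory, a segment whose restart.nc already exists is skipped on re-runs, and WANDB_NAME gets a -segment_NNNN suffix. A rough standalone sketch of such a loop, based only on those observed expectations (run_segments and run_one_segment are hypothetical names, not the fme implementation):

import os
from typing import Callable


def run_segments(
    experiment_dir: str, n_segments: int, run_one_segment: Callable[[str], None]
) -> None:
    base_name = os.environ.get("WANDB_NAME", "")
    for i in range(n_segments):
        segment_dir = os.path.join(experiment_dir, f"segment_{i:04d}")
        # a completed segment leaves a restart file behind; do not re-run it
        if os.path.exists(os.path.join(segment_dir, "restart.nc")):
            continue
        os.makedirs(segment_dir, exist_ok=True)
        if base_name:
            # give each segment its own wandb run name
            os.environ["WANDB_NAME"] = f"{base_name}-segment_{i:04d}"
        run_one_segment(segment_dir)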
diff --git a/fme/fme/ace/models/healpix/__init__.py b/fme/fme/ace/models/healpix/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/fme/fme/ace/models/healpix/healpix_activations.py b/fme/fme/ace/models/healpix/healpix_activations.py
new file mode 100644
index 0000000..336cc7a
--- /dev/null
+++ b/fme/fme/ace/models/healpix/healpix_activations.py
@@ -0,0 +1,206 @@
+# flake8: noqa
+# Copied from https://github.com/NVIDIA/modulus/commit/89a6091bd21edce7be4e0539cbd91507004faf08
+# Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import dataclasses
+from typing import Literal
+
+import torch as th
+import torch.nn as nn
+
+from .healpix_layers import HEALPixLayer
+
+# DOWNSAMPLING BLOCKS
+
+
+class MaxPool(nn.Module):
+ """Wrapper for applying Max Pooling with HEALPix or other tensor data.
+
+ This class wraps the `nn.MaxPool2d` class to handle tensor data with
+ HEALPix or other geometry layers.
+ """
+
+ def __init__(
+ self,
+ pooling: int = 2,
+ enable_nhwc: bool = False,
+ enable_healpixpad: bool = False,
+ ):
+ """
+ Args:
+ pooling (int, optional): Pooling kernel size passed to geometry layer.
+ enable_nhwc (bool, optional): Enable nhwc format, passed to wrapper.
+ enable_healpixpad (bool, optional): If HEALPixPadding should be enabled, passed to wrapper.
+ """
+ super().__init__()
+ self.maxpool = HEALPixLayer(
+ layer=nn.MaxPool2d,
+ kernel_size=pooling,
+ enable_nhwc=enable_nhwc,
+ enable_healpixpad=enable_healpixpad,
+ )
+
+ def forward(self, x: th.Tensor) -> th.Tensor:
+ """Forward pass of the MaxPool.
+
+ Args:
+ x: The values to MaxPool.
+
+ Returns:
+ The MaxPooled values.
+ """
+ return self.maxpool(x)
+
+
+class AvgPool(nn.Module):
+ """Wrapper for applying Average Pooling with HEALPix or other tensor data.
+
+ This class wraps the `nn.AvgPool2d` class to handle tensor data with
+ HEALPix or other geometry layers.
+ """
+
+ def __init__(
+ self,
+ pooling: int = 2,
+ enable_nhwc: bool = False,
+ enable_healpixpad: bool = False,
+ ):
+ """
+ Args:
+ pooling (int, optional): Pooling kernel size passed to geometry layer.
+ enable_nhwc (bool, optional): Enable nhwc format, passed to wrapper.
+ enable_healpixpad (bool, optional): If HEALPixPadding should be enabled, passed to wrapper.
+ """
+ super().__init__()
+ self.avgpool = HEALPixLayer(
+ layer=nn.AvgPool2d,
+ kernel_size=pooling,
+ enable_nhwc=enable_nhwc,
+ enable_healpixpad=enable_healpixpad,
+ )
+
+ def forward(self, x: th.Tensor) -> th.Tensor:
+ """Forward pass of the AvgPool layer.
+
+ Args:
+ x: The values to average.
+
+ Returns:
+ The averaged values.
+ """
+ return self.avgpool(x)
+
+
+@dataclasses.dataclass
+class DownsamplingBlockConfig:
+ """
+ Configuration for the downsampling block.
+ Generally, either a pooling block or a striding conv block.
+
+ Parameters:
+        block_type: Type of downsampling block, either "MaxPool" or "AvgPool"
+ pooling: Pooling size
+ enable_nhwc: Flag to enable NHWC data format, default is False.
+ enable_healpixpad: Flag to enable HEALPix padding, default is False.
+
+ """
+
+ block_type: Literal["MaxPool", "AvgPool"]
+ pooling: int = 2
+ enable_nhwc: bool = False
+ enable_healpixpad: bool = False
+
+ def build(self) -> nn.Module:
+ """
+        Builds the downsampling block.
+
+        Returns:
+            Downsampling block.
+ """
+ if self.block_type == "MaxPool":
+ return MaxPool(
+ pooling=self.pooling,
+ enable_nhwc=self.enable_nhwc,
+ enable_healpixpad=self.enable_healpixpad,
+ )
+
+ elif self.block_type == "AvgPool":
+ return AvgPool(
+ pooling=self.pooling,
+ enable_nhwc=self.enable_nhwc,
+ enable_healpixpad=self.enable_healpixpad,
+ )
+ else:
+ raise ValueError(f"Unsupported block type: {self.block_type}")
+
+
+@dataclasses.dataclass
+class CappedGELUConfig:
+ """
+ Configuration for the CappedGELU activation function.
+
+ Parameters:
+ cap_value: Cap value for the GELU function, default is 10.
+ enable_nhwc: Flag to enable NHWC data format, default is False.
+ enable_healpixpad: Flag to enable HEALPix padding, default is False.
+ """
+
+ cap_value: int = 10
+ enable_nhwc: bool = False
+ enable_healpixpad: bool = False
+
+ def build(self) -> nn.Module:
+ """
+ Builds the CappedGELU activation function.
+
+ Returns:
+ CappedGELU activation function.
+ """
+ return CappedGELU(cap_value=self.cap_value)
+
+
+class CappedGELU(nn.Module):
+ """
+ Implements a GELU with capped maximum value.
+
+ Example
+ -------
+ >>> capped_gelu_func = modulus.models.layers.CappedGELU()
+ >>> input = th.Tensor([[-2,-1],[0,1],[2,3]])
+ >>> capped_gelu_func(input)
+ tensor([[-0.0455, -0.1587],
+ [ 0.0000, 0.8413],
+ [ 1.0000, 1.0000]])
+
+ """
+
+ def __init__(self, cap_value=1.0, **kwargs):
+ """
+ Args:
+ cap_value: Maximum that values will be capped at
+ **kwargs: Keyword arguments to be passed to the `th.nn.GELU` function
+ """
+
+ super().__init__()
+ self.add_module("gelu", th.nn.GELU(**kwargs))
+ self.register_buffer("cap", th.tensor(cap_value, dtype=th.float32))
+
+ def forward(self, inputs):
+ x = self.gelu(inputs)
+ # Convert cap to a scalar value for clamping (ignores grad)
+ cap_value = self.cap.item()
+ x = th.clamp(x, max=cap_value)
+ return x
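A quick standalone check of the capping behavior documented above: a GELU followed by an upper clamp at cap_value. This uses plain torch rather than importing the vendored module, so it is only a sketch of the idea:

import torch
import torch.nn.functional as F


def capped_gelu(x: torch.Tensor, cap_value: float = 1.0) -> torch.Tensor:
    # GELU, then clamp from above, mirroring what CappedGELU.forward does
    return torch.clamp(F.gelu(x), max=cap_value)


x = torch.tensor([[-2.0, -1.0], [0.0, 1.0], [2.0, 3.0]])
print(capped_gelu(x))  # large positive inputs saturate at cap_value (1.0 here)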
diff --git a/fme/fme/ace/models/healpix/healpix_blocks.py b/fme/fme/ace/models/healpix/healpix_blocks.py
new file mode 100644
index 0000000..e1c667b
--- /dev/null
+++ b/fme/fme/ace/models/healpix/healpix_blocks.py
@@ -0,0 +1,926 @@
+# flake8: noqa
+# Copied from https://github.com/NVIDIA/modulus/commit/89a6091bd21edce7be4e0539cbd91507004faf08
+# Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import abc
+import dataclasses
+from typing import Literal, Optional, Tuple, Union, cast
+
+import torch as th
+import torch.nn as nn
+
+from .healpix_activations import CappedGELUConfig
+from .healpix_layers import HEALPixLayer
+
+# RECURRENT BLOCKS
+
+
+@dataclasses.dataclass
+class RecurrentBlockConfig:
+ """
+ Configuration for the recurrent block.
+
+ Parameters:
+ in_channels: Number of input channels, default is 3.
+ kernel_size: Size of the kernel, default is 1.
+ enable_nhwc: Flag to enable NHWC data format, default is False.
+ enable_healpixpad: Flag to enable HEALPix padding, default is False.
+ block_type: Type of recurrent block, either "ConvGRUBlock" or "ConvLSTMBlock",
+ default is "ConvGRUBlock".
+ """
+
+ in_channels: int = 3
+ kernel_size: int = 1
+ enable_nhwc: bool = False
+ enable_healpixpad: bool = False
+ block_type: Literal["ConvGRUBlock", "ConvLSTMBlock"] = "ConvGRUBlock"
+
+ def build(self) -> nn.Module:
+ """
+ Builds the recurrent block model.
+
+ Returns:
+ Recurrent block.
+ """
+ if self.block_type == "ConvGRUBlock":
+ return ConvGRUBlock(
+ in_channels=self.in_channels,
+ kernel_size=self.kernel_size,
+ enable_nhwc=self.enable_nhwc,
+ enable_healpixpad=self.enable_healpixpad,
+ )
+ elif self.block_type == "ConvLSTMBlock":
+ return ConvLSTMBlock(
+ in_channels=self.in_channels,
+ kernel_size=self.kernel_size,
+ enable_nhwc=self.enable_nhwc,
+ enable_healpixpad=self.enable_healpixpad,
+ )
+ else:
+ raise ValueError(f"Unsupported block type: {self.block_type}")
+
+
+@dataclasses.dataclass
+class ConvBlockConfig:
+ """
+ Configuration for the convolutional block.
+
+ Parameters:
+ in_channels: Number of input channels, default is 3.
+ out_channels: Number of output channels, default is 1.
+ kernel_size: Size of the kernel, default is 3.
+ dilation: Dilation rate, default is 1.
+ n_layers: Number of layers, default is 1.
+ upsampling: Upsampling factor for TransposedConvUpsample, default is 2.
+ upscale_factor: Upscale factor for ConvNeXtBlock and SymmetricConvNeXtBlock,
+ default is 4.
+ latent_channels: Number of latent channels, default is None.
+ activation: Activation configuration, default is None.
+ enable_nhwc: Flag to enable NHWC data format, default is False.
+ enable_healpixpad: Flag to enable HEALPix padding, default is False.
+ block_type: Type of block, default is "BasicConvBlock".
+ """
+
+ in_channels: int = 3
+ out_channels: int = 1
+ kernel_size: int = 3
+ dilation: int = 1
+ n_layers: int = 1
+ upsampling: int = 2
+ upscale_factor: int = 4
+ latent_channels: Optional[int] = None
+ activation: Optional[CappedGELUConfig] = None
+ enable_nhwc: bool = False
+ enable_healpixpad: bool = False
+ block_type: Literal[
+ "BasicConvBlock",
+ "ConvNeXtBlock",
+ "SymmetricConvNeXtBlock",
+ "TransposedConvUpsample",
+ ] = "BasicConvBlock"
+
+ def build(self) -> nn.Module:
+ """
+ Builds the convolutional block model.
+
+ Returns:
+ Convolutional block model.
+ """
+ if self.block_type == "BasicConvBlock":
+ return BasicConvBlock(
+ in_channels=self.in_channels,
+ out_channels=self.out_channels,
+ kernel_size=self.kernel_size,
+ dilation=self.dilation,
+ n_layers=self.n_layers,
+ latent_channels=self.latent_channels,
+ activation=self.activation,
+ enable_nhwc=self.enable_nhwc,
+ enable_healpixpad=self.enable_healpixpad,
+ )
+ elif self.block_type == "ConvNeXtBlock":
+ if self.latent_channels is None:
+ self.latent_channels = 1
+ return ConvNeXtBlock(
+ in_channels=self.in_channels,
+ latent_channels=cast(int, self.latent_channels),
+ out_channels=self.out_channels,
+ kernel_size=self.kernel_size,
+ dilation=self.dilation,
+ upscale_factor=self.upscale_factor,
+ activation=self.activation,
+ enable_nhwc=self.enable_nhwc,
+ enable_healpixpad=self.enable_healpixpad,
+ )
+ elif self.block_type == "SymmetricConvNeXtBlock":
+ if self.latent_channels is None:
+ self.latent_channels = 1
+ return SymmetricConvNeXtBlock(
+ in_channels=self.in_channels,
+ latent_channels=cast(int, self.latent_channels),
+ out_channels=self.out_channels,
+ kernel_size=self.kernel_size,
+ dilation=self.dilation,
+ upscale_factor=self.upscale_factor,
+ activation=self.activation,
+ enable_nhwc=self.enable_nhwc,
+ enable_healpixpad=self.enable_healpixpad,
+ )
+ elif self.block_type == "TransposedConvUpsample":
+ return TransposedConvUpsample(
+ in_channels=self.in_channels,
+ out_channels=self.out_channels,
+ upsampling=self.upsampling,
+ activation=self.activation,
+ enable_nhwc=self.enable_nhwc,
+ enable_healpixpad=self.enable_healpixpad,
+ )
+ else:
+ raise ValueError(f"Unsupported block type: {self.block_type}")
+
+
+class ConvGRUBlock(nn.Module):
+ """Class that implements a Convolutional GRU.
+
+ Code modified from:
+ https://github.com/happyjin/ConvGRU-pytorch/blob/master/convGRU.py
+ """
+
+ def __init__(
+ self,
+ in_channels=3,
+ kernel_size=1,
+ enable_nhwc=False,
+ enable_healpixpad=False,
+ ):
+ """
+ Args:
+ in_channels: The number of input channels.
+ kernel_size: Size of the convolutional kernel.
+ enable_nhwc: Enable nhwc format, passed to wrapper.
+ enable_healpixpad: If HEALPixPadding should be enabled, passed to wrapper.
+ """
+ super().__init__()
+
+ self.channels = in_channels
+ self.conv_gates = HEALPixLayer(
+ layer=th.nn.Conv2d,
+ in_channels=in_channels + self.channels,
+ out_channels=2 * self.channels, # for update_gate, reset_gate respectively
+ kernel_size=kernel_size,
+ padding="same",
+ enable_nhwc=enable_nhwc,
+ enable_healpixpad=enable_healpixpad,
+ )
+ self.conv_can = HEALPixLayer(
+ layer=th.nn.Conv2d,
+ in_channels=in_channels + self.channels,
+ out_channels=self.channels, # for candidate neural memory
+ kernel_size=kernel_size,
+ padding="same",
+ enable_nhwc=enable_nhwc,
+ enable_healpixpad=enable_healpixpad,
+ )
+ self.h = th.zeros(1, 1, 1, 1)
+
+ def forward(self, inputs):
+ """Forward pass of the ConvGRUBlock.
+
+ Args:
+ inputs: Input to the forward pass.
+
+ Returns:
+ th.Tensor: Result of the forward pass.
+ """
+ if inputs.shape != self.h.shape:
+ self.h = th.zeros_like(inputs)
+ combined = th.cat([inputs, self.h], dim=1)
+ combined_conv = self.conv_gates(combined)
+
+ gamma, beta = th.split(combined_conv, self.channels, dim=1)
+ reset_gate = th.sigmoid(gamma)
+ update_gate = th.sigmoid(beta)
+
+ combined = th.cat([inputs, reset_gate * self.h], dim=1)
+ cc_cnm = self.conv_can(combined)
+ cnm = th.tanh(cc_cnm)
+
+ h_next = (1 - update_gate) * self.h + update_gate * cnm
+ self.h = h_next
+
+ return inputs + h_next
+
+ def reset(self):
+ """Reset the update gates."""
+ self.h = th.zeros_like(self.h)
+
+
+class ConvLSTMBlock(nn.Module):
+ """Convolutional LSTM block."""
+
+ def __init__(
+ self,
+ in_channels: int = 3,
+ out_channels: int = 1,
+ latent_channels: int = 1,
+ kernel_size: int = 3,
+ downscale_factor: int = 4,
+ upscale_factor: int = 4,
+ n_layers: int = 1,
+ latent_conv_size: int = 3, # Add latent_conv_size parameter
+ dilation: int = 1,
+ activation: nn.Module = None,
+ enable_nhwc: bool = False,
+ enable_healpixpad: bool = False,
+ ):
+ """
+ Args:
+ in_channels: The number of input channels.
+ out_channels: The number of output channels.
+ latent_channels: Number of latent channels.
+ kernel_size: Size of the convolutional kernel.
+ downscale_factor: Downscale factor.
+ upscale_factor: Upscale factor.
+ n_layers: Number of layers.
+ latent_conv_size: Size of latent convolution.
+ dilation: Spacing between kernel points.
+ activation: Activation function.
+ enable_nhwc: Enable nhwc format.
+ enable_healpixpad: If HEALPixPadding should be enabled.
+ """
+ super().__init__()
+ # Instantiate 1x1 conv to increase/decrease channel depth if necessary
+ # Skip connection for output
+ if in_channels == out_channels:
+ self.skip_module = lambda x: x # Identity-function required in forward pass
+ else:
+ self.skip_module = HEALPixLayer(
+ layer=th.nn.Conv2d,
+ in_channels=in_channels,
+                out_channels=in_channels,  # the skip output is added to the LSTM state, which has in_channels channels, so keep that size
+ kernel_size=1,
+ enable_nhwc=enable_nhwc,
+ enable_healpixpad=enable_healpixpad,
+ )
+ # Convolution block
+ convblock = []
+ # 3x3 convolution increasing channels
+ convblock.append(
+ HEALPixLayer(
+ layer=th.nn.Conv2d,
+ in_channels=in_channels
+ * 2, # accounts for the h layer, which is concatenated before convolution runs
+ out_channels=int(latent_channels * upscale_factor),
+ kernel_size=kernel_size,
+ dilation=dilation,
+ enable_nhwc=enable_nhwc,
+ enable_healpixpad=enable_healpixpad,
+ )
+ )
+ if activation is not None:
+ convblock.append(activation)
+ # 3x3 convolution maintaining increased channels
+ convblock.append(
+ HEALPixLayer(
+ layer=th.nn.Conv2d,
+ in_channels=int(latent_channels * upscale_factor),
+ out_channels=int(latent_channels * upscale_factor),
+ kernel_size=kernel_size,
+ dilation=dilation,
+ enable_nhwc=enable_nhwc,
+ enable_healpixpad=enable_healpixpad,
+ )
+ )
+ if activation is not None:
+ convblock.append(activation)
+
+ # Now for the LSTM bit
+ self.channels = in_channels
+ self.lstm_gates = HEALPixLayer(
+ layer=th.nn.Conv2d,
+ in_channels=latent_channels * upscale_factor,
+ out_channels=self.channels
+ * 4, # for input_gate, forget_gate, cell_gate, output_gate respectively (LSTM)
+ kernel_size=kernel_size,
+ padding="same",
+ enable_nhwc=enable_nhwc,
+ enable_healpixpad=enable_healpixpad,
+ )
+ self.h = th.zeros(1, 1, 1, 1)
+ self.c = th.zeros(1, 1, 1, 1)
+ self.convblock = nn.Sequential(*convblock)
+
+ def forward(self, inputs):
+ """Forward pass of the ConvLSTMBlock.
+
+ Args:
+            inputs: Inputs to the forward pass.
+
+ Returns:
+ th.Tensor: Result of the forward pass.
+ """
+ if inputs.shape != self.h.shape:
+ self.h = th.zeros_like(inputs)
+ self.c = th.zeros_like(inputs)
+
+ combined = th.cat([inputs, self.h], dim=1)
+ conv_outputs = self.convblock(combined)
+
+ lstm_gates = self.lstm_gates(conv_outputs)
+
+ # Split the combined_conv into input_gate, forget_gate, cell_gate, output_gate
+ i, f, c_hat, o = th.split(lstm_gates, self.channels, dim=1)
+ input_gate = th.sigmoid(i)
+ forget_gate = th.sigmoid(f)
+ cell_gate = th.tanh(c_hat)
+ output_gate = th.sigmoid(o)
+
+ self.c = forget_gate * self.c + input_gate * cell_gate
+ self.h = output_gate * th.tanh(self.c)
+
+ skip_connection = self.skip_module(inputs)
+ return skip_connection + self.h
+
+ def reset(self):
+ self.h = th.zeros_like(self.h)
+ self.c = th.zeros_like(self.c)
+
+
+# CONV BLOCKS
+
+
+class BasicConvBlock(nn.Module):
+ """Convolution block consisting of n subsequent convolutions and activations."""
+
+ def __init__(
+ self,
+ in_channels=3,
+ out_channels=1,
+ kernel_size=3,
+ dilation=1,
+ n_layers=1,
+ latent_channels=None,
+ activation=None,
+ enable_nhwc=False,
+ enable_healpixpad=False,
+ ):
+ """
+ Args:
+ in_channels: The number of input channels.
+ out_channels: The number of output channels.
+ kernel_size: Size of the convolutional kernel.
+ dilation: Spacing between kernel points, passed to nn.Conv2d.
+ n_layers: Number of convolutional layers.
+ latent_channels: Number of latent channels.
+ activation: ModuleConfig for activation function to use.
+ enable_nhwc: Enable nhwc format, passed to wrapper.
+            enable_healpixpad: If HEALPixPadding should be enabled, passed to wrapper.
+ """
+ super().__init__()
+ if latent_channels is None:
+ latent_channels = max(in_channels, out_channels)
+ convblock = []
+ for n in range(n_layers):
+ convblock.append(
+ HEALPixLayer(
+ layer=th.nn.Conv2d,
+ in_channels=in_channels if n == 0 else latent_channels,
+ out_channels=out_channels if n == n_layers - 1 else latent_channels,
+ kernel_size=kernel_size,
+ dilation=dilation,
+ enable_nhwc=enable_nhwc,
+ enable_healpixpad=enable_healpixpad,
+ )
+ )
+ if activation is not None:
+ convblock.append(activation.build())
+ self.convblock = nn.Sequential(*convblock)
+
+ def forward(self, x):
+ """Forward pass of the BasicConvBlock.
+
+ Args:
+ x: Inputs to the forward pass.
+
+ Returns:
+ th.Tensor: Result of the forward pass.
+ """
+ return self.convblock(x)
+
+
+class ConvNeXtBlock(nn.Module):
+ """A modified ConvNeXt network block as described in the paper
+ "A ConvNet for the 21st Century" (https://arxiv.org/pdf/2201.03545.pdf).
+
+ This block consists of a series of convolutional layers with optional activation functions,
+ and a residual connection.
+
+ Parameters:
+ skip_module: A module to align the input and output channels for the residual connection.
+ convblock: A sequential container of convolutional layers with optional activation functions.
+ """
+
+ def __init__(
+ self,
+ in_channels: int = 3,
+ latent_channels: int = 1,
+ out_channels: int = 1,
+ kernel_size: int = 3,
+ dilation: int = 1,
+ upscale_factor: int = 4,
+ activation: Optional[CappedGELUConfig] = None,
+ enable_nhwc: bool = False,
+ enable_healpixpad: bool = False,
+ ):
+ """
+ Initializes a ConvNeXtBlock instance with specified parameters.
+
+ Args:
+ in_channels: Number of input channels.
+ latent_channels: Number of latent channels used in the block.
+ out_channels: Number of output channels.
+ kernel_size: Size of the convolutional kernels.
+ dilation: Dilation rate for convolutions.
+ upscale_factor: Factor by which to upscale the number of latent channels.
+ activation: Configuration for the activation function used between layers.
+ enable_nhwc: Whether to enable NHWC format.
+ enable_healpixpad: Whether to enable HEALPixPadding.
+ """
+ super().__init__()
+
+ # Instantiate 1x1 conv to increase/decrease channel depth if necessary
+ if in_channels == out_channels:
+ self.skip_module = lambda x: x # Identity-function required in forward pass
+ else:
+ self.skip_module = HEALPixLayer(
+ layer=th.nn.Conv2d,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=1,
+ enable_nhwc=enable_nhwc,
+ enable_healpixpad=enable_healpixpad,
+ )
+ # Convolution block
+ convblock = []
+ # 3x3 convolution increasing channels
+ convblock.append(
+ HEALPixLayer(
+ layer=th.nn.Conv2d,
+ in_channels=in_channels,
+ out_channels=int(latent_channels * upscale_factor),
+ kernel_size=kernel_size,
+ dilation=dilation,
+ enable_nhwc=enable_nhwc,
+ enable_healpixpad=enable_healpixpad,
+ )
+ )
+ if activation is not None:
+ convblock.append(activation.build())
+ # 3x3 convolution maintaining increased channels
+ convblock.append(
+ HEALPixLayer(
+ layer=th.nn.Conv2d,
+ in_channels=int(latent_channels * upscale_factor),
+ out_channels=int(latent_channels * upscale_factor),
+ kernel_size=kernel_size,
+ dilation=dilation,
+ enable_nhwc=enable_nhwc,
+ enable_healpixpad=enable_healpixpad,
+ )
+ )
+ if activation is not None:
+ convblock.append(activation.build())
+ # Linear postprocessing
+ convblock.append(
+ HEALPixLayer(
+ layer=th.nn.Conv2d,
+ in_channels=int(latent_channels * upscale_factor),
+ out_channels=out_channels,
+ kernel_size=1,
+ enable_nhwc=enable_nhwc,
+ enable_healpixpad=enable_healpixpad,
+ )
+ )
+ self.convblock = nn.Sequential(*convblock)
+
+ def forward(self, x):
+ """Forward pass of the ConvNeXtBlock.
+
+ Args:
+ x: Input tensor.
+
+ Returns:
+ The result of the forward pass.
+ """
+ return self.skip_module(x) + self.convblock(x)
+
+
+class DoubleConvNeXtBlock(nn.Module):
+ """A variant of the ConvNeXt block that includes two sequential ConvNeXt blocks within a single module.
+
+ Parameters:
+ skip_module1: A module to align the input and intermediate channels for the first residual connection.
+ skip_module2: A module to align the intermediate and output channels for the second residual connection.
+ convblock1: A sequential container of convolutional layers for the first ConvNeXt block.
+ convblock2: A sequential container of convolutional layers for the second ConvNeXt block.
+ """
+
+ def __init__(
+ self,
+ in_channels: int = 3,
+ out_channels: int = 1,
+ kernel_size: int = 3,
+ dilation: int = 1,
+ upscale_factor: int = 4,
+ latent_channels: int = 1,
+ activation: Optional[CappedGELUConfig] = None,
+ enable_nhwc: bool = False,
+ enable_healpixpad: bool = False,
+ ):
+ """
+ Initializes a DoubleConvNeXtBlock instance with specified parameters.
+
+ Args:
+ in_channels: Number of input channels (default is 3).
+ out_channels: Number of output channels (default is 1).
+ kernel_size: Size of the convolutional kernels (default is 3).
+ dilation: Dilation rate for convolutions (default is 1).
+ upscale_factor: Factor by which to upscale the number of latent channels (default is 4).
+ latent_channels: Number of latent channels used in the block (default is 1).
+ activation: Configuration for the activation function used between layers (default is None).
+ enable_nhwc: Whether to enable NHWC format (default is False).
+ enable_healpixpad: Whether to enable HEALPixPadding (default is False).
+ """
+ super().__init__()
+
+ if in_channels == int(latent_channels):
+ self.skip_module1 = (
+ lambda x: x
+ ) # Identity-function required in forward pass
+ else:
+ self.skip_module1 = HEALPixLayer(
+ layer=th.nn.Conv2d,
+ in_channels=in_channels,
+ out_channels=int(latent_channels),
+ kernel_size=1,
+ enable_nhwc=enable_nhwc,
+ enable_healpixpad=enable_healpixpad,
+ )
+ if out_channels == int(latent_channels):
+ self.skip_module2 = (
+ lambda x: x
+ ) # Identity-function required in forward pass
+ else:
+ self.skip_module2 = HEALPixLayer(
+ layer=th.nn.Conv2d,
+ in_channels=int(latent_channels),
+ out_channels=out_channels,
+ kernel_size=1,
+ enable_nhwc=enable_nhwc,
+ enable_healpixpad=enable_healpixpad,
+ )
+
+ # 1st ConvNeXt block, the output of this one remains internal
+ convblock1 = []
+        # 3x3 convolution establishing latent channels
+ convblock1.append(
+ HEALPixLayer(
+ layer=th.nn.Conv2d,
+ in_channels=in_channels,
+ out_channels=int(latent_channels),
+ kernel_size=kernel_size,
+ dilation=dilation,
+ enable_nhwc=enable_nhwc,
+ enable_healpixpad=enable_healpixpad,
+ )
+ )
+ if activation is not None:
+ convblock1.append(activation.build())
+ # 1x1 convolution establishing increased channels
+ convblock1.append(
+ HEALPixLayer(
+ layer=th.nn.Conv2d,
+ in_channels=int(latent_channels),
+ out_channels=int(latent_channels * upscale_factor),
+ kernel_size=1,
+ dilation=dilation,
+ enable_nhwc=enable_nhwc,
+ enable_healpixpad=enable_healpixpad,
+ )
+ )
+ if activation is not None:
+ convblock1.append(activation.build())
+ # 1x1 convolution returning to latent channels
+ convblock1.append(
+ HEALPixLayer(
+ layer=th.nn.Conv2d,
+ in_channels=int(latent_channels * upscale_factor),
+ out_channels=int(latent_channels),
+ kernel_size=1,
+ dilation=dilation,
+ enable_nhwc=enable_nhwc,
+ enable_healpixpad=enable_healpixpad,
+ )
+ )
+ if activation is not None:
+ convblock1.append(activation.build())
+ self.convblock1 = nn.Sequential(*convblock1)
+
+        # 2nd ConvNeXt block, takes the output of the first ConvNeXt block
+        convblock2 = []
+        # 3x3 convolution establishing latent channels
+ convblock2.append(
+ HEALPixLayer(
+ layer=th.nn.Conv2d,
+ in_channels=int(latent_channels),
+ out_channels=int(latent_channels),
+ kernel_size=kernel_size,
+ dilation=dilation,
+ enable_nhwc=enable_nhwc,
+ enable_healpixpad=enable_healpixpad,
+ )
+ )
+ if activation is not None:
+ convblock2.append(activation.build())
+ # 1x1 convolution establishing increased channels
+ convblock2.append(
+ HEALPixLayer(
+ layer=th.nn.Conv2d,
+ in_channels=int(latent_channels),
+ out_channels=int(latent_channels * upscale_factor),
+ kernel_size=1,
+ dilation=dilation,
+ enable_nhwc=enable_nhwc,
+ enable_healpixpad=enable_healpixpad,
+ )
+ )
+ if activation is not None:
+ convblock2.append(activation.build())
+ # 1x1 convolution reducing to output channels
+ convblock2.append(
+ HEALPixLayer(
+ layer=th.nn.Conv2d,
+ in_channels=int(latent_channels * upscale_factor),
+ out_channels=out_channels,
+ kernel_size=1,
+ dilation=dilation,
+ enable_nhwc=enable_nhwc,
+ enable_healpixpad=enable_healpixpad,
+ )
+ )
+ if activation is not None:
+ convblock2.append(activation.build())
+ self.convblock2 = nn.Sequential(*convblock2)
+
+ def forward(self, x):
+ """Forward pass of the DoubleConvNextBlock
+ Args:
+ x: inputs to the forward pass
+ Returns:
+ result of the forward pass
+ """
+ # internal convnext result
+ x1 = self.skip_module1(x) + self.convblock1(x)
+ # return second convnext result
+ return self.skip_module2(x1) + self.convblock2(x1)
+
+
+class SymmetricConvNeXtBlock(nn.Module):
+ """A symmetric variant of the ConvNeXt block, with convolutional layers mirrored
+ around a central axis for symmetric feature extraction.
+
+ Parameters:
+        skip_module: A module to align the input and output channels for the residual connection.
+        convblock: A sequential container of convolutional layers for the symmetric ConvNeXt block.
+ """
+
+ def __init__(
+ self,
+ in_channels: int = 3,
+ latent_channels: int = 1,
+ out_channels: int = 1,
+ kernel_size: int = 3,
+ dilation: int = 1,
+ upscale_factor: int = 4,
+ activation: Optional[CappedGELUConfig] = None,
+ enable_nhwc: bool = False,
+ enable_healpixpad: bool = False,
+ ):
+ """
+ Initializes a SymmetricConvNeXtBlock instance with specified parameters.
+
+ Args:
+ in_channels: Number of input channels (default is 3).
+ out_channels: Number of output channels (default is 1).
+ kernel_size: Size of the convolutional kernels (default is 3).
+ dilation: Dilation rate for convolutions (default is 1).
+ upscale_factor: Upscale factor.
+ latent_channels: Number of latent channels used in the block (default is 1).
+ activation: Configuration for the activation function used between layers (default is None).
+ enable_nhwc: Whether to enable NHWC format (default is False).
+ enable_healpixpad: Whether to enable HEALPixPadding (default is False).
+ """
+ if in_channels == int(latent_channels):
+ self.skip_module = lambda x: x # Identity-function required in forward pass
+ else:
+ self.skip_module = HEALPixLayer(
+ layer=th.nn.Conv2d,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=1,
+ enable_nhwc=enable_nhwc,
+ enable_healpixpad=enable_healpixpad,
+ )
+
+ # 1st ConvNeXt block, the output of this one remains internal
+ convblock = []
+        # 3x3 convolution establishing latent channels
+ convblock.append(
+ HEALPixLayer(
+ layer=th.nn.Conv2d,
+ in_channels=in_channels,
+ out_channels=int(latent_channels),
+ kernel_size=kernel_size,
+ dilation=dilation,
+ enable_nhwc=enable_nhwc,
+ enable_healpixpad=enable_healpixpad,
+ )
+ )
+ if activation is not None:
+ convblock.append(activation.build())
+ # 1x1 convolution establishing increased channels
+ convblock.append(
+ HEALPixLayer(
+ layer=th.nn.Conv2d,
+ in_channels=int(latent_channels),
+ out_channels=int(latent_channels * upscale_factor),
+ kernel_size=1,
+ dilation=dilation,
+ enable_nhwc=enable_nhwc,
+ enable_healpixpad=enable_healpixpad,
+ )
+ )
+ if activation is not None:
+ convblock.append(activation.build())
+ # 1x1 convolution returning to latent channels
+ convblock.append(
+ HEALPixLayer(
+ layer=th.nn.Conv2d,
+ in_channels=int(latent_channels * upscale_factor),
+ out_channels=int(latent_channels),
+ kernel_size=1,
+ dilation=dilation,
+ enable_nhwc=enable_nhwc,
+ enable_healpixpad=enable_healpixpad,
+ )
+ )
+ if activation is not None:
+ convblock.append(activation.build())
+        # 3x3 convolution from latent channels to output channels
+ convblock.append(
+ HEALPixLayer(
+ layer=th.nn.Conv2d,
+ in_channels=int(latent_channels),
+                out_channels=out_channels,
+ kernel_size=kernel_size,
+ dilation=dilation,
+ enable_nhwc=enable_nhwc,
+ enable_healpixpad=enable_healpixpad,
+ )
+ )
+ if activation is not None:
+ convblock.append(activation.build())
+ self.convblock = nn.Sequential(*convblock)
+
+ def forward(self, x):
+ """Forward pass of the SymmetricConvNextBlock
+ Args:
+ x: inputs to the forward pass
+ Returns:
+ result of the forward pass
+ """
+ # residual connection with reshaped inpute and output of conv block
+ return self.skip_module(x) + self.convblock(x)
+
+
+class TransposedConvUpsample(nn.Module):
+ """Wrapper for upsampling with a transposed convolution using HEALPix or other tensor data.
+
+ This class wraps the `nn.ConvTranspose2d` class to handle tensor data with
+ HEALPix or other geometry layers.
+ """
+
+ def __init__(
+ self,
+ in_channels: int = 3,
+ out_channels: int = 1,
+ upsampling: int = 2,
+ activation: Optional[CappedGELUConfig] = None,
+ enable_nhwc: bool = False,
+ enable_healpixpad: bool = False,
+ ):
+ """
+ Args:
+ in_channels: The number of input channels.
+ out_channels: The number of output channels.
+ upsampling: Size used for upsampling.
+ activation: ModuleConfig for the activation function used in upsampling.
+ enable_nhwc: Enable nhwc format, passed to wrapper.
+ enable_healpixpad: If HEALPixPadding should be enabled, passed to wrapper.
+ """
+ super().__init__()
+ upsampler = []
+ # Upsample transpose conv
+ upsampler.append(
+ HEALPixLayer(
+ layer=nn.ConvTranspose2d,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=upsampling,
+ stride=upsampling,
+ padding=0,
+ enable_nhwc=enable_nhwc,
+ enable_healpixpad=enable_healpixpad,
+ )
+ )
+ if activation is not None:
+ upsampler.append(activation.build())
+ self.upsampler = nn.Sequential(*upsampler)
+
+ def forward(self, x):
+ """Forward pass of the TransposedConvUpsample layer.
+
+ Args:
+ x: The values to upsample.
+
+ Returns:
+ th.Tensor: The upsampled values.
+ """
+ return self.upsampler(x)
+
+
+# Helpers
+
+
+class Interpolate(nn.Module):
+ """Helper class for interpolation.
+
+ This class handles interpolation, storing scale factor and mode for
+ `nn.functional.interpolate`.
+ """
+
+ def __init__(self, scale_factor: Union[int, Tuple], mode: str = "nearest"):
+ """
+ Args:
+ scale_factor: Multiplier for spatial size, passed to `nn.functional.interpolate`.
+            mode: Interpolation mode used for upsampling, passed to `nn.functional.interpolate`.
+ """
+ super().__init__()
+ self.interp = nn.functional.interpolate
+ self.scale_factor = scale_factor
+ self.mode = mode
+
+ def forward(self, inputs):
+ """Forward pass of the Interpolate layer.
+
+ Args:
+ inputs: Inputs to interpolate.
+
+ Returns:
+ th.Tensor: The interpolated values.
+ """
+ return self.interp(inputs, scale_factor=self.scale_factor, mode=self.mode)
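The conv blocks above all follow the same residual pattern: a 1x1 skip projection when the channel counts differ, plus a small stack of convolutions whose output is added back to the skip path. A simplified sketch of that pattern with plain nn.Conv2d standing in for HEALPixLayer (so no HEALPix padding here, and the hidden width playing the role of latent_channels * upscale_factor):

import torch
import torch.nn as nn


class TinyResidualBlock(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, hidden: int):
        super().__init__()
        # 1x1 projection only when the skip path must change channel count
        self.skip = (
            nn.Identity()
            if in_channels == out_channels
            else nn.Conv2d(in_channels, out_channels, kernel_size=1)
        )
        self.body = nn.Sequential(
            nn.Conv2d(in_channels, hidden, kernel_size=3, padding=1),
            nn.GELU(),
            nn.Conv2d(hidden, hidden, kernel_size=3, padding=1),
            nn.GELU(),
            nn.Conv2d(hidden, out_channels, kernel_size=1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.skip(x) + self.body(x)


block = TinyResidualBlock(in_channels=3, out_channels=8, hidden=12)
print(block(torch.rand(2, 3, 16, 16)).shape)  # torch.Size([2, 8, 16, 16])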
diff --git a/fme/fme/ace/models/healpix/healpix_decoder.py b/fme/fme/ace/models/healpix/healpix_decoder.py
new file mode 100644
index 0000000..5dd5f15
--- /dev/null
+++ b/fme/fme/ace/models/healpix/healpix_decoder.py
@@ -0,0 +1,187 @@
+# flake8: noqa
+# Copied from https://github.com/NVIDIA/modulus/commit/89a6091bd21edce7be4e0539cbd91507004faf08
+# Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import dataclasses
+from typing import List, Optional, Sequence
+
+import torch as th
+import torch.nn as nn
+
+from .healpix_blocks import ConvBlockConfig, RecurrentBlockConfig
+
+
+@dataclasses.dataclass
+class UNetDecoderConfig:
+ """
+ Configuration for the UNet Decoder.
+
+ Parameters:
+ conv_block: Configuration for the convolutional block.
+ up_sampling_block: Configuration for the up-sampling block.
+ output_layer: Configuration for the output layer block.
+ recurrent_block: Configuration for the recurrent block, by default None.
+ n_channels: Number of channels for each layer, by default (34, 68, 136).
+ n_layers: Number of layers in each block, by default (1, 2, 2).
+ output_channels: Number of output channels, by default 1.
+ dilations: List of dilation rates for the layers, by default None.
+ enable_nhwc: Flag to enable NHWC data format, by default False.
+ enable_healpixpad: Flag to enable HEALPix padding, by default False.
+ """
+
+ conv_block: ConvBlockConfig
+ up_sampling_block: ConvBlockConfig
+ output_layer: ConvBlockConfig
+ recurrent_block: Optional[RecurrentBlockConfig] = None
+ n_channels: List[int] = dataclasses.field(default_factory=lambda: [34, 68, 136])
+ n_layers: List[int] = dataclasses.field(default_factory=lambda: [1, 2, 2])
+ output_channels: int = 1
+ dilations: Optional[list] = None
+ enable_nhwc: bool = False
+ enable_healpixpad: bool = False
+
+ def build(self) -> nn.Module:
+ """
+ Builds the UNet Decoder model.
+
+ Returns:
+ UNet Decoder model.
+ """
+ return UNetDecoder(
+ conv_block=self.conv_block,
+ up_sampling_block=self.up_sampling_block,
+ output_layer=self.output_layer,
+ recurrent_block=self.recurrent_block,
+ n_channels=self.n_channels,
+ n_layers=self.n_layers,
+ output_channels=self.output_channels,
+ dilations=self.dilations,
+ enable_nhwc=self.enable_nhwc,
+ enable_healpixpad=self.enable_healpixpad,
+ )
+
+
+class UNetDecoder(nn.Module):
+ """Generic UNetDecoder that can be applied to arbitrary meshes."""
+
+ def __init__(
+ self,
+ conv_block: ConvBlockConfig,
+ up_sampling_block: ConvBlockConfig,
+ output_layer: ConvBlockConfig,
+ recurrent_block: Optional[RecurrentBlockConfig] = None,
+ n_channels: Sequence = (64, 32, 16),
+ n_layers: Sequence = (1, 2, 2),
+ output_channels: int = 1,
+ dilations: Optional[list] = None,
+ enable_nhwc: bool = False,
+ enable_healpixpad: bool = False,
+ ):
+ """
+ Initialize the UNetDecoder.
+
+ Args:
+ conv_block: Configuration for the convolutional block.
+ up_sampling_block: Configuration for the upsampling block.
+ output_layer: Configuration for the output layer.
+ recurrent_block: Configuration for the recurrent block. If None, recurrent blocks are not used.
+ n_channels: Sequence specifying the number of channels in each decoder layer.
+ n_layers: Sequence specifying the number of layers in each block.
+ output_channels: Number of output channels.
+ dilations: List of dilations to use for the convolutional blocks.
+ enable_nhwc: If True, use channel last format.
+ enable_healpixpad: If True, use the healpixpad library if installed.
+ """
+ super().__init__()
+ self.channel_dim = 1
+
+ if dilations is None:
+ dilations = [1 for _ in range(len(n_channels))]
+
+ self.decoder = []
+ for n, curr_channel in enumerate(n_channels):
+ up_sample_module = None
+ if n != 0:
+ up_sampling_block.in_channels = curr_channel
+ up_sampling_block.out_channels = curr_channel
+ up_sampling_block.enable_nhwc = enable_nhwc
+ up_sampling_block.enable_healpixpad = enable_healpixpad
+ up_sample_module = up_sampling_block.build()
+
+ next_channel = (
+ n_channels[n + 1] if n < len(n_channels) - 1 else n_channels[-1]
+ )
+
+ conv_block.in_channels = curr_channel * 2 if n > 0 else curr_channel
+ conv_block.latent_channels = curr_channel
+ conv_block.out_channels = next_channel
+ conv_block.dilation = dilations[n]
+ conv_block.n_layers = n_layers[n]
+ conv_block.enable_nhwc = enable_nhwc
+ conv_block.enable_healpixpad = enable_healpixpad
+ conv_module = conv_block.build()
+
+ rec_module = None
+ if recurrent_block is not None:
+ recurrent_block.in_channels = next_channel
+ recurrent_block.enable_healpixpad = enable_healpixpad
+ rec_module = recurrent_block.build()
+
+ self.decoder.append(
+ nn.ModuleDict(
+ {
+ "upsamp": up_sample_module,
+ "conv": conv_module,
+ "recurrent": rec_module,
+ }
+ )
+ )
+
+ self.decoder = nn.ModuleList(self.decoder)
+
+ output_layer.in_channels = curr_channel
+ output_layer.out_channels = output_channels
+ output_layer.dilation = dilations[-1]
+ output_layer.enable_nhwc = enable_nhwc
+ output_layer.enable_healpixpad = enable_healpixpad
+
+ self.output_layer = output_layer.build()
+
+ def forward(self, inputs):
+ """
+ Forward pass of the UNetDecoder.
+
+ Args:
+ inputs: The inputs to the forward pass.
+
+ Returns:
+ The decoded values.
+ """
+ x = inputs[-1]
+ for n, layer in enumerate(self.decoder):
+ if layer["upsamp"] is not None:
+ up = layer["upsamp"](x)
+ x = th.cat([up, inputs[-1 - n]], dim=self.channel_dim)
+ x = layer["conv"](x)
+ if layer["recurrent"] is not None:
+ x = layer["recurrent"](x)
+ return self.output_layer(x)
+
+ def reset(self):
+ """Resets the state of the decoder layers."""
+ for layer in self.decoder:
+ if layer["recurrent"] is not None:
+ layer["recurrent"].reset()
diff --git a/fme/fme/ace/models/healpix/healpix_encoder.py b/fme/fme/ace/models/healpix/healpix_encoder.py
new file mode 100644
index 0000000..8471abb
--- /dev/null
+++ b/fme/fme/ace/models/healpix/healpix_encoder.py
@@ -0,0 +1,149 @@
+# flake8: noqa
+# Copied from https://github.com/NVIDIA/modulus/commit/89a6091bd21edce7be4e0539cbd91507004faf08
+# Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import dataclasses
+from typing import List, Optional, Sequence
+
+import torch.nn as nn
+
+from fme.ace.models.healpix.healpix_activations import DownsamplingBlockConfig
+
+from .healpix_blocks import ConvBlockConfig
+
+
+@dataclasses.dataclass
+class UNetEncoderConfig:
+ """
+ Configuration for the UNet Encoder.
+
+ Parameters:
+ conv_block: Configuration for the convolutional block.
+ down_sampling_block: Configuration for the down-sampling block.
+ input_channels: Number of input channels, by default 3.
+ n_channels: Number of channels for each layer, by default (136, 68, 34).
+ n_layers: Number of layers in each block, by default (2, 2, 1).
+ dilations: List of dilation rates for the layers, by default None.
+ enable_nhwc: Flag to enable NHWC data format, by default False.
+ enable_healpixpad: Flag to enable HEALPix padding, by default False.
+ """
+
+ conv_block: ConvBlockConfig
+ down_sampling_block: DownsamplingBlockConfig
+ input_channels: int = 3
+ n_channels: List[int] = dataclasses.field(default_factory=lambda: [136, 68, 34])
+ n_layers: List[int] = dataclasses.field(default_factory=lambda: [2, 2, 1])
+ dilations: Optional[list] = None
+ enable_nhwc: bool = False
+ enable_healpixpad: bool = False
+
+ def build(self) -> nn.Module:
+ """
+ Builds the UNet Encoder model.
+
+ Returns:
+ UNet Encoder model.
+ """
+ return UNetEncoder(
+ conv_block=self.conv_block,
+ down_sampling_block=self.down_sampling_block,
+ input_channels=self.input_channels,
+ n_channels=self.n_channels,
+ n_layers=self.n_layers,
+ dilations=self.dilations,
+ enable_nhwc=self.enable_nhwc,
+ enable_healpixpad=self.enable_healpixpad,
+ )
+
+
+class UNetEncoder(nn.Module):
+ """Generic UNetEncoder that can be applied to arbitrary meshes."""
+
+ def __init__(
+ self,
+ conv_block: ConvBlockConfig,
+ down_sampling_block: DownsamplingBlockConfig,
+ input_channels: int = 3,
+ n_channels: Sequence = (16, 32, 64),
+ n_layers: Sequence = (2, 2, 1),
+ dilations: Optional[list] = None,
+ enable_nhwc: bool = False,
+ enable_healpixpad: bool = False,
+ ):
+ """
+ Args:
+ conv_block: config for the convolutional block
+ down_sampling_block: DownsamplingBlockConfig for the downsample block
+ input_channels: # of input channels
+ n_channels: # of channels in each encoder layer
+            n_layers: # of layers to use for the convolutional blocks
+            dilations: list of dilations to use for the convolutional blocks
+ enable_nhwc: if channel last format should be used
+ enable_healpixpad: if healpixpad library should be used (true if installed)
+ """
+ super().__init__()
+ self.n_channels = n_channels
+
+ if dilations is None:
+ # Defaults to [1, 1, 1...] in accordance with the number of unet levels
+ dilations = [1 for _ in range(len(n_channels))]
+
+ # Build encoder
+ old_channels = input_channels
+ self.encoder = []
+ for n, curr_channel in enumerate(n_channels):
+ modules = list()
+ if n > 0:
+ down_sampling_block.enable_nhwc = enable_nhwc
+ down_sampling_block.enable_healpixpad = enable_healpixpad
+ modules.append(
+ down_sampling_block.build() # Shapes are not used in these calls.
+ )
+
+ # Set up conv block
+ conv_block.in_channels = old_channels
+ conv_block.latent_channels = curr_channel
+ conv_block.out_channels = curr_channel
+ conv_block.dilation = dilations[n]
+ conv_block.n_layers = n_layers[n]
+ conv_block.enable_nhwc = enable_nhwc
+ conv_block.enable_healpixpad = enable_healpixpad
+ modules.append(conv_block.build()) # Shapes are not used in these calls.
+ old_channels = curr_channel
+
+ self.encoder.append(nn.Sequential(*modules))
+
+ self.encoder = nn.ModuleList(self.encoder)
+
+ def forward(self, inputs: Sequence) -> Sequence:
+ """
+ Forward pass of the HEALPix Unet encoder
+
+ Args:
+            inputs: The inputs to encode
+
+ Returns:
+ The encoded values
+ """
+ outputs = []
+ for layer in self.encoder:
+ outputs.append(layer(inputs))
+ inputs = outputs[-1]
+ return outputs
+
+ def reset(self):
+ """Resets the state of the decoder layers"""
+ pass
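The encoder pattern above is simple: each level consumes the previous level's output, and every level's output is returned so the decoder can use them as skip connections. A plain-torch sketch with nn.Conv2d/nn.MaxPool2d standing in for the HEALPix-aware layers, and made-up channel counts:

import torch
import torch.nn as nn

levels = nn.ModuleList(
    [
        nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.GELU()),
        nn.Sequential(nn.MaxPool2d(2), nn.Conv2d(16, 32, 3, padding=1), nn.GELU()),
        nn.Sequential(nn.MaxPool2d(2), nn.Conv2d(32, 64, 3, padding=1), nn.GELU()),
    ]
)

x = torch.rand(12, 3, 16, 16)  # 12 folded HEALPix faces as the batch dimension
outputs = []
inputs = x
for level in levels:
    outputs.append(level(inputs))
    inputs = outputs[-1]  # each level feeds on the previous level's output
print([tuple(o.shape) for o in outputs])
# [(12, 16, 16, 16), (12, 32, 8, 8), (12, 64, 4, 4)]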
diff --git a/fme/fme/ace/models/healpix/healpix_layers.py b/fme/fme/ace/models/healpix/healpix_layers.py
new file mode 100644
index 0000000..edbeb39
--- /dev/null
+++ b/fme/fme/ace/models/healpix/healpix_layers.py
@@ -0,0 +1,493 @@
+# flake8: noqa
+# Copied from https://github.com/NVIDIA/modulus/commit/89a6091bd21edce7be4e0539cbd91507004faf08
+# Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This file contains padding and convolution classes to perform according operations on the twelve faces of the HEALPix.
+
+
+ HEALPix Face order 3D array representation
+ -----------------
+-------------------------- //\\ //\\ //\\ //\\ | | | | |
+|| 0 | 1 | 2 | 3 || // \\// \\// \\// \\ |0 |1 |2 |3 |
+|\\ //\\ //\\ //\\ //| /\\0 //\\1 //\\2 //\\3 // -----------------
+| \\// \\// \\// \\// | // \\// \\// \\// \\// | | | | |
+|4//\\5 //\\6 //\\7 //\\4| \\4//\\5 //\\6 //\\7 //\\ |4 |5 |6 |7 |
+|// \\// \\// \\// \\| \\/ \\// \\// \\// \\ -----------------
+|| 8 | 9 | 10 | 11 | \\8 //\\9 //\\10//\\11// | | | | |
+-------------------------- \\// \\// \\// \\// |8 |9 |10 |11 |
+ -----------------
+ "\\" are top and bottom, whereas
+ "//" are left and right borders
+
+
+Details on the HEALPix can be found at https://iopscience.iop.org/article/10.1086/427976
+
+"""
+
+import logging
+import sys
+
+import torch as th
+import torch.nn as nn
+
+
+class HEALPixFoldFaces(nn.Module):
+ """Class that folds the faces of a HealPIX tensor"""
+
+ def __init__(self, enable_nhwc: bool = False):
+ """
+ Args:
+ enable_nhwc: Use nhwc format instead of nchw format
+ """
+ super().__init__()
+ self.enable_nhwc = enable_nhwc
+
+ def forward(self, tensor: th.Tensor) -> th.Tensor:
+ """
+ Forward pass that folds a HEALPix tensor
+ [B, F, C, H, W] -> [B*F, C, H, W]
+
+ Args:
+ tensor: The tensor to fold
+
+ Returns:
+ th.Tensor: the folded tensor
+
+ """
+ N, F, C, H, W = tensor.shape
+ tensor = th.reshape(tensor, shape=(N * F, C, H, W))
+
+ if self.enable_nhwc:
+ tensor = tensor.to(memory_format=th.channels_last)
+
+ return tensor
+
+
+class HEALPixUnfoldFaces(nn.Module):
+ """Class that unfolds the faces of a HealPIX tensor"""
+
+ def __init__(self, num_faces=12, enable_nhwc=False):
+ """
+ Args:
+ num_faces: The number of faces on the grid, default 12
+ enable_nhwc: If nhwc format is being used, default False
+ """
+ super().__init__()
+ self.num_faces = num_faces
+ self.enable_nhwc = enable_nhwc
+
+ def forward(self, tensor: th.Tensor) -> th.Tensor:
+ """
+ Forward pass that unfolds a HEALPix tensor
+ [B*F, C, H, W] -> [B, F, C, H, W]
+
+ Args:
+ tensor: the tensor to unfold
+
+ Returns:
+ The unfolded tensor
+
+ """
+ NF, C, H, W = tensor.shape
+ N = int(NF / self.num_faces)
+ tensor = th.reshape(tensor, shape=(N, self.num_faces, C, H, W))
+
+ return tensor
+
+
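A quick standalone check of the fold/unfold reshapes implemented above, [B, F, C, H, W] <-> [B*F, C, H, W]; plain torch reshapes, not the classes themselves:

import torch

x = torch.arange(2 * 12 * 3 * 4 * 4, dtype=torch.float32).reshape(2, 12, 3, 4, 4)
folded = x.reshape(2 * 12, 3, 4, 4)        # what HEALPixFoldFaces does
unfolded = folded.reshape(2, 12, 3, 4, 4)  # what HEALPixUnfoldFaces does
assert torch.equal(x, unfolded)  # the round trip preserves the data exactly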
+class HEALPixPadding(nn.Module):
+ """
+ Padding layer for data on a HEALPix sphere. The requirements for using this layer are as follows:
+ - The last three dimensions are (face=12, height, width)
+ - The first four indices in the faces dimension [0, 1, 2, 3] are the faces on the northern hemisphere
+ - The second four indices in the faces dimension [4, 5, 6, 7] are the faces on the equator
+ - The last four indices in the faces dimension [8, 9, 10, 11] are the faces on the southern hemisphere
+
+ Orientation and arrangement of the HEALPix faces are outlined above.
+ """
+
+ def __init__(self, padding: int, enable_nhwc: bool = False):
+ """
+ Args:
+ padding: The padding size
+ enable_nhwc: If nhwc format is being used, default False
+ """
+ super().__init__()
+ self.p = padding
+ self.d = [-2, -1]
+ self.enable_nhwc = enable_nhwc
+ if not isinstance(padding, int) or padding < 1:
+ raise ValueError(
+ f"invalid value for 'padding', expected int > 0 but got {padding}"
+ )
+
+ self.fold = HEALPixFoldFaces(enable_nhwc=self.enable_nhwc)
+ self.unfold = HEALPixUnfoldFaces(num_faces=12, enable_nhwc=self.enable_nhwc)
+
+ def forward(self, data: th.Tensor) -> th.Tensor:
+ """
+        Pad each face consistently with its corresponding neighbors in the HEALPix (see ordering and neighborhoods above).
+        Assumes the tensor is folded.
+
+ Args:
+ data: The input tensor of shape [..., F, H, W] where each face is to be padded in its HPX context
+
+ Returns:
+ The padded tensor where each face's height and width are increased by 2*p
+ """
+ th.cuda.nvtx.range_push("HEALPixPadding:forward")
+
+ # unfold faces from batch dim
+ data = self.unfold(data)
+
+ # Extract the twelve faces (as views of the original tensors)
+ f00, f01, f02, f03, f04, f05, f06, f07, f08, f09, f10, f11 = [
+ th.squeeze(x, dim=1)
+ for x in th.split(tensor=data, split_size_or_sections=1, dim=1)
+ ]
+
+ # Assemble the four padded faces on the northern hemisphere
+ p00 = self.pn(
+ c=f00, t=f01, tl=f02, lft=f03, bl=f03, b=f04, br=f08, rgt=f05, tr=f01
+ )
+ p01 = self.pn(
+ c=f01, t=f02, tl=f03, lft=f00, bl=f00, b=f05, br=f09, rgt=f06, tr=f02
+ )
+ p02 = self.pn(
+ c=f02, t=f03, tl=f00, lft=f01, bl=f01, b=f06, br=f10, rgt=f07, tr=f03
+ )
+ p03 = self.pn(
+ c=f03, t=f00, tl=f01, lft=f02, bl=f02, b=f07, br=f11, rgt=f04, tr=f00
+ )
+
+ # Assemble the four padded faces on the equator
+ p04 = self.pe(
+ c=f04,
+ t=f00,
+ tl=self.tl(f00, f03),
+ lft=f03,
+ bl=f07,
+ b=f11,
+ br=self.br(f11, f08),
+ rgt=f08,
+ tr=f05,
+ )
+ p05 = self.pe(
+ c=f05,
+ t=f01,
+ tl=self.tl(f01, f00),
+ lft=f00,
+ bl=f04,
+ b=f08,
+ br=self.br(f08, f09),
+ rgt=f09,
+ tr=f06,
+ )
+ p06 = self.pe(
+ c=f06,
+ t=f02,
+ tl=self.tl(f02, f01),
+ lft=f01,
+ bl=f05,
+ b=f09,
+ br=self.br(f09, f10),
+ rgt=f10,
+ tr=f07,
+ )
+ p07 = self.pe(
+ c=f07,
+ t=f03,
+ tl=self.tl(f03, f02),
+ lft=f02,
+ bl=f06,
+ b=f10,
+ br=self.br(f10, f11),
+ rgt=f11,
+ tr=f04,
+ )
+
+ # Assemble the four padded faces on the southern hemisphere
+ p08 = self.ps(
+ c=f08, t=f05, tl=f00, lft=f04, bl=f11, b=f11, br=f10, rgt=f09, tr=f09
+ )
+ p09 = self.ps(
+ c=f09, t=f06, tl=f01, lft=f05, bl=f08, b=f08, br=f11, rgt=f10, tr=f10
+ )
+ p10 = self.ps(
+ c=f10, t=f07, tl=f02, lft=f06, bl=f09, b=f09, br=f08, rgt=f11, tr=f11
+ )
+ p11 = self.ps(
+ c=f11, t=f04, tl=f03, lft=f07, bl=f10, b=f10, br=f09, rgt=f08, tr=f08
+ )
+
+ res = th.stack(
+ (p00, p01, p02, p03, p04, p05, p06, p07, p08, p09, p10, p11), dim=1
+ )
+
+ # fold faces into batch dim
+ res = self.fold(res)
+
+ th.cuda.nvtx.range_pop()
+
+ return res
+
+ def pn(
+ self,
+ c: th.Tensor,
+ t: th.Tensor,
+ tl: th.Tensor,
+ lft: th.Tensor,
+ bl: th.Tensor,
+ b: th.Tensor,
+ br: th.Tensor,
+ rgt: th.Tensor,
+ tr: th.Tensor,
+ ) -> th.Tensor:
+ """
+        Applies padding to a northern hemisphere face c, taking its given neighbors into account.
+
+ Args:
+ c: The central face and tensor that is subject for padding
+ t: The top neighboring face tensor
+ tl: The top left neighboring face tensor
+ lft: The left neighboring face tensor
+ bl: The bottom left neighboring face tensor
+ b: The bottom neighboring face tensor
+ br: The bottom right neighboring face tensor
+ rgt: The right neighboring face tensor
+ tr: The top right neighboring face tensor
+
+ Returns:
+ The padded tensor p
+ """
+ p = self.p # Padding size
+ d = self.d # Dimensions for rotations
+
+ # Start with top and bottom to extend the height of the c tensor
+ c = th.cat((t.rot90(1, d)[..., -p:, :], c, b[..., :p, :]), dim=-2)
+
+ # Construct the left and right pads including the corner faces
+ left = th.cat(
+ (
+ tl.rot90(2, d)[..., -p:, -p:],
+ lft.rot90(-1, d)[..., -p:],
+ bl[..., :p, -p:],
+ ),
+ dim=-2,
+ )
+ right = th.cat((tr[..., -p:, :p], rgt[..., :p], br[..., :p, :p]), dim=-2)
+
+ return th.cat((left, c, right), dim=-1)
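+    # Note: after the top/bottom and left/right concatenations above, the padded
+    # face grows from [..., H, W] to [..., H + 2*p, W + 2*p]; pe and ps below
+    # produce the same output shape, only the neighbor orientations differ.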
+
+ def pe(
+ self,
+ c: th.Tensor,
+ t: th.Tensor,
+ tl: th.Tensor,
+ lft: th.Tensor,
+ bl: th.Tensor,
+ b: th.Tensor,
+ br: th.Tensor,
+ rgt: th.Tensor,
+ tr: th.Tensor,
+ ) -> th.Tensor:
+ """
+        Applies padding to an equatorial face c, taking its given neighbors into account.
+
+ Args:
+ c: The central face and tensor that is subject for padding
+ t: The top neighboring face tensor
+ tl: The top left neighboring face tensor
+ lft: The left neighboring face tensor
+ bl: The bottom left neighboring face tensor
+ b: The bottom neighboring face tensor
+ br: The bottom right neighboring face tensor
+ rgt: The right neighboring face tensor
+ tr: The top right neighboring face tensor
+
+        Returns:
+            The padded tensor p
+ """
+ p = self.p # Padding size
+
+ # Start with top and bottom to extend the height of the c tensor
+ c = th.cat((t[..., -p:, :], c, b[..., :p, :]), dim=-2)
+
+ # Construct the left and right pads including the corner faces
+ left = th.cat((tl[..., -p:, -p:], lft[..., -p:], bl[..., :p, -p:]), dim=-2)
+ right = th.cat((tr[..., -p:, :p], rgt[..., :p], br[..., :p, :p]), dim=-2)
+
+ return th.cat((left, c, right), dim=-1)
+
+ def ps(
+ self,
+ c: th.Tensor,
+ t: th.Tensor,
+ tl: th.Tensor,
+ lft: th.Tensor,
+ bl: th.Tensor,
+ b: th.Tensor,
+ br: th.Tensor,
+ rgt: th.Tensor,
+ tr: th.Tensor,
+ ) -> th.Tensor:
+ """
+        Applies padding to a southern hemisphere face c, taking its given neighbors into account.
+
+ Args:
+ c: The central face and tensor that is subject for padding
+ t: The top neighboring face tensor
+ tl: The top left neighboring face tensor
+ lft: The left neighboring face tensor
+ bl: The bottom left neighboring face tensor
+ b: The bottom neighboring face tensor
+ br: The bottom right neighboring face tensor
+ rgt: The right neighboring face tensor
+ tr: The top right neighboring face tensor
+
+ Returns:
+ The padded tensor p
+ """
+ p = self.p # Padding size
+ d = self.d # Dimensions for rotations
+
+ # Start with top and bottom to extend the height of the c tensor
+ c = th.cat((t[..., -p:, :], c, b.rot90(1, d)[..., :p, :]), dim=-2)
+
+ # Construct the left and right pads including the corner faces
+ left = th.cat((tl[..., -p:, -p:], lft[..., -p:], bl[..., :p, -p:]), dim=-2)
+ right = th.cat(
+ (tr[..., -p:, :p], rgt.rot90(-1, d)[..., :p], br.rot90(2, d)[..., :p, :p]),
+ dim=-2,
+ )
+
+ return th.cat((left, c, right), dim=-1)
+
+ def tl(self, top: th.Tensor, lft: th.Tensor) -> th.Tensor:
+ """
+        Assembles the top left corner of a center face in cases where no corresponding top left face is defined on
+        the HPX.
+
+ Args:
+ top: The face above the center face
+ lft: The face left of the center face
+
+ Returns:
+ The assembled top left corner (only the sub-part that is required for padding)
+ """
+ ret = th.zeros_like(top)[..., : self.p, : self.p] # super ugly but super fast
+
+ # Bottom left point
+ ret[..., -1, -1] = 0.5 * top[..., -1, 0] + 0.5 * lft[..., 0, -1]
+
+ # Remaining points
+ for i in range(1, self.p):
+ ret[..., -i - 1, -i:] = top[
+ ..., -i - 1, :i
+ ] # Filling top right above main diagonal
+ ret[..., -i:, -i - 1] = lft[
+ ..., :i, -i - 1
+ ] # Filling bottom left below main diagonal
+ ret[..., -i - 1, -i - 1] = (
+ 0.5 * top[..., -i - 1, 0] + 0.5 * lft[..., 0, -i - 1]
+ ) # Diagonal
+
+ return ret
+
+ def br(self, b: th.Tensor, r: th.Tensor) -> th.Tensor:
+ """
+        Assembles the bottom right corner of a center face in cases where no corresponding bottom right face is
+        defined on the HPX.
+
+ Args:
+ b: The face below the center face
+ r: The face right of the center face
+
+ Returns:
+ The assembled bottom right corner (only the sub-part that is required for padding)
+ """
+ ret = th.zeros_like(b)[..., : self.p, : self.p]
+
+ # Top left point
+ ret[..., 0, 0] = 0.5 * b[..., 0, -1] + 0.5 * r[..., -1, 0]
+
+ # Remaining points
+ for i in range(1, self.p):
+ ret[..., :i, i] = r[..., -i:, i] # Filling top right above main diagonal
+ ret[..., i, :i] = b[..., i, -i:] # Filling bottom left below main diagonal
+ ret[..., i, i] = 0.5 * b[..., i, -1] + 0.5 * r[..., -1, i] # Diagonal
+
+ return ret
+
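+# A minimal usage sketch for HEALPixPadding (illustrative shapes only):
+#
+#     pad = HEALPixPadding(padding=1)
+#     x = th.randn(2 * 12, 4, 8, 8)  # folded input [B*F, C, H, W]
+#     y = pad(x)                     # [24, 4, 10, 10]; each face padded from its HEALPix neighbors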
+
+class HEALPixLayer(nn.Module):
+ """Pytorch module for applying any base torch Module on a HEALPix tensor. Expects all input/output tensors to have a
+ shape [..., 12, H, W], where 12 is the dimension of the faces.
+ """
+
+ def __init__(self, layer, **kwargs):
+ """
+ Args:
+ layer: Any torch layer function, e.g., nn.Conv2d
+ kwargs: The arguments that are passed to the torch layer function, e.g., kernel_size
+ """
+ super().__init__()
+ layers = []
+
+ if "enable_nhwc" in kwargs:
+ enable_nhwc = kwargs["enable_nhwc"]
+ del kwargs["enable_nhwc"]
+ else:
+ enable_nhwc = False
+
+ if "enable_healpixpad" in kwargs and kwargs["enable_healpixpad"]:
+ raise NotImplementedError(
+ "HEALPixPaddingv2 is not available in this environment"
+ )
+
+ if "enable_healpixpad" in kwargs:
+ del kwargs["enable_healpixpad"]
+
+ if not isinstance(layer, type) or not issubclass(layer, th.nn.Module):
+ raise TypeError(
+ f"Expected a subclass of torch.nn.Module, got {type(layer).__name__}"
+ )
+ # Define a HEALPixPadding layer if the given layer is a convolution layer
+        if layer.__bases__[0] is nn.modules.conv._ConvNd and kwargs.get("kernel_size", 3) > 1:
+ kwargs["padding"] = 0 # Disable native padding
+ kernel_size = 3 if "kernel_size" not in kwargs else kwargs["kernel_size"]
+ dilation = 1 if "dilation" not in kwargs else kwargs["dilation"]
+ padding = ((kernel_size - 1) // 2) * dilation
+ layers.append(HEALPixPadding(padding=padding, enable_nhwc=enable_nhwc))
+
+ layers.append(layer(**kwargs))
+ self.layers = nn.Sequential(*layers)
+
+ if enable_nhwc:
+ self.layers = self.layers.to(memory_format=th.channels_last)
+
+ def forward(self, x: th.Tensor) -> th.Tensor:
+ """
+ Performs the forward pass using the defined layer function and the given data.
+
+        Args:
+            x: The input tensor of shape [..., F=12, H, W]
+
+        Returns:
+            The output tensor of this HEALPix layer
+ """
+ res = self.layers(x)
+ return res
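+
+
+# A minimal usage sketch for HEALPixLayer (illustrative; for a conv layer with
+# kernel_size > 1 it inserts a HEALPixPadding of size ((kernel_size - 1) // 2) * dilation
+# and disables the conv's own padding, so the spatial size is preserved):
+#
+#     conv = HEALPixLayer(layer=nn.Conv2d, in_channels=4, out_channels=8, kernel_size=3)
+#     x = th.randn(2 * 12, 4, 8, 8)  # folded input [B*F, C, H, W]
+#     y = conv(x)                    # [24, 8, 8, 8]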
diff --git a/fme/fme/ace/models/healpix/healpix_recunet.py b/fme/fme/ace/models/healpix/healpix_recunet.py
new file mode 100644
index 0000000..56b8d78
--- /dev/null
+++ b/fme/fme/ace/models/healpix/healpix_recunet.py
@@ -0,0 +1,436 @@
+# flake8: noqa
+# Copied from https://github.com/NVIDIA/modulus/commit/89a6091bd21edce7be4e0539cbd91507004faf08
+# Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Sequence
+
+import pandas as pd
+import torch as th
+import torch.nn as nn
+
+from .healpix_decoder import UNetDecoderConfig
+from .healpix_encoder import UNetEncoderConfig
+from .healpix_layers import HEALPixFoldFaces, HEALPixUnfoldFaces
+
+
+class HEALPixRecUNet(nn.Module):
+ """Deep Learning Weather Prediction (DLWP) recurrent UNet model on the HEALPix mesh."""
+
+ def __init__(
+ self,
+ encoder: UNetEncoderConfig,
+ decoder: UNetDecoderConfig,
+ input_channels: int,
+ output_channels: int,
+ prognostic_variables: int,
+ n_constants: int,
+ decoder_input_channels: int,
+ input_time_size: int,
+ output_time_size: int,
+ delta_time: str = "6h",
+ reset_cycle: str = "24h",
+ presteps: int = 1,
+ enable_nhwc: bool = False,
+ enable_healpixpad: bool = False,
+ couplings: list = [],
+ ):
+ """
+ Initialize the HEALPixRecUNet model.
+
+ Args:
+ encoder: UNetEncoderConfig
+ ModuleConfig of instantiable parameters for the U-net encoder.
+ decoder: UNetDecoderConfig
+ ModuleConfig of instantiable parameters for the U-net decoder.
+ input_channels: int
+ Number of input channels expected in the input array schema. Note this should be the
+ number of input variables in the data, NOT including data reshaping for the encoder part.
+ output_channels: int
+ Number of output channels expected in the output array schema, or output variables.
+ n_constants: int
+ Number of optional constants expected in the input arrays. If this is zero, no constants
+ should be provided as inputs to forward.
+ decoder_input_channels: int
+ Number of optional prescribed variables expected in the decoder input array
+ for both inputs and outputs. If this is zero, no decoder inputs should be provided as inputs to forward.
+ input_time_size: int
+ Number of time steps in the input array.
+ output_time_size: int
+ Number of time steps in the output array.
+            delta_time: str, optional
+                Time between two consecutive data points, e.g. "6h".
+            reset_cycle: str, optional
+                Time interval after which the recurrent states are reset to zero and re-initialized. Set to a
+                very large value to never reset the hidden states.
+ presteps: int, optional
+ Number of model steps to initialize recurrent states.
+ enable_nhwc: bool, optional
+ Model with [N, H, W, C] instead of [N, C, H, W].
+ enable_healpixpad: bool, optional
+ Enable CUDA HEALPixPadding if installed.
+            couplings: list, optional
+                Sequence of dictionaries that describe coupling mechanisms. Currently unused in our production
+                model, but kept in the module definition in case we bring in our SST module (which subclasses
+                this one).
+ """
+ super().__init__()
+ self.channel_dim = 2 # Now 2 with [B, F, T*C, H, W]. Was 1 in old data format with [B, T*C, F, H, W]
+
+ self.input_channels = input_channels
+
+ if n_constants == 0 and decoder_input_channels == 0:
+ pass
+ # raise NotImplementedError(
+ # "support for models with no constant fields and no decoder inputs (TOA insolation) is not available at this time."
+ # )
+ if couplings is not None:
+ if len(couplings) > 0:
+ if n_constants == 0:
+ raise NotImplementedError(
+ "support for coupled models with no constant fields is not available at this time."
+ )
+ if decoder_input_channels == 0:
+ raise NotImplementedError(
+ "support for coupled models with no decoder inputs (TOA insolation) is not available at this time."
+ )
+ else:
+ couplings = []
+
+ # add coupled fields to input channels for model initialization
+ self.coupled_channels = self._compute_coupled_channels(couplings)
+ self.couplings = couplings
+ self.train_couplers = None
+ self.output_channels = output_channels
+ self.n_constants = n_constants
+ self.prognostic_variables = prognostic_variables
+ self.decoder_input_channels = decoder_input_channels
+ self.input_time_size = input_time_size
+ self.output_time_size = output_time_size
+ self.delta_t = int(pd.Timedelta(delta_time).total_seconds() // 3600)
+ self.reset_cycle = int(pd.Timedelta(reset_cycle).total_seconds() // 3600)
+ self.presteps = presteps
+ self.enable_nhwc = enable_nhwc
+ self.enable_healpixpad = enable_healpixpad
+
+        # Either multiple passes through the model, or a diagnostic model with only one output time
+ self.is_diagnostic = self.output_time_size == 1 and self.input_time_size > 1
+ if not self.is_diagnostic and (
+ self.output_time_size % self.input_time_size != 0
+ ):
+ raise ValueError(
+ f"'output_time_size' must be a multiple of 'input_time_size' (got "
+ f"{self.output_time_size} and {self.input_time_size})"
+ )
+
+ # Build the model layers
+ self.fold = HEALPixFoldFaces()
+ self.unfold = HEALPixUnfoldFaces(num_faces=12)
+ encoder.input_channels = self._compute_input_channels()
+ encoder.enable_nhwc = self.enable_nhwc
+ encoder.enable_healpixpad = self.enable_healpixpad
+ self.encoder = encoder.build()
+
+ self.encoder_depth = len(self.encoder.n_channels)
+ decoder.output_channels = self._compute_output_channels()
+ decoder.enable_nhwc = self.enable_nhwc
+ decoder.enable_healpixpad = self.enable_healpixpad
+ self.decoder = decoder.build()
+
+ @property
+ def integration_steps(self):
+ """Number of integration steps"""
+ return max(self.output_time_size // self.input_time_size, 1)
+
+ def _compute_input_channels(self) -> int:
+ """
+ Calculate total number of input channels in the model.
+
+ Returns:
+ int: The total number of input channels.
+ """
+ return (
+ self.input_time_size * (self.input_channels + self.decoder_input_channels)
+ + self.n_constants
+ + self.coupled_channels
+ )
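+    # For example (illustrative values): with input_time_size=2, input_channels=7,
+    # decoder_input_channels=1, n_constants=2 and no couplings, the encoder sees
+    # 2 * (7 + 1) + 2 + 0 = 18 input channels; with output_time_size=4 the model
+    # then makes max(4 // 2, 1) = 2 integration steps.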
+
+ def _compute_coupled_channels(self, couplings):
+ """
+ Get the number of coupled channels.
+
+ Args:
+ couplings: list
+ Sequence of dictionaries that describe coupling mechanisms.
+
+ Returns:
+ int: The number of coupled channels.
+ """
+ c_channels = 0
+ for c in couplings:
+ c_channels += len(c["params"]["variables"]) * len(
+ c["params"]["input_times"]
+ )
+ return c_channels
+
+ def _compute_output_channels(self) -> int:
+ """
+ Compute the total number of output channels in the model.
+
+ Returns:
+ int: The total number of output channels.
+ """
+ return (
+ 1 if self.is_diagnostic else self.input_time_size
+ ) * self.output_channels
+
+ # def _reshape_inputs(self, inputs: Sequence, step: int = 0) -> th.Tensor:
+ # """
+ # Returns a single tensor to pass into the model encoder/decoder. Squashes the time/channel dimension and
+ # concatenates in constants and decoder inputs.
+
+ # Args:
+ # inputs: list of expected input tensors (inputs, decoder_inputs, constants)
+ # step: step number in the sequence of integration_steps
+
+ # Returns:
+ # reshaped Tensor in expected shape for model encoder
+ # """
+
+ # # if len(self.couplings) > 0:
+ # # result = [
+ # # inputs[0].flatten(
+ # # start_dim=self.channel_dim, end_dim=self.channel_dim + 1
+ # # ),
+ # # inputs[1][
+ # # :,
+ # # :,
+ # # slice(step * self.input_time_size, (step + 1) * self.input_time_size),
+ # # ...,
+ # # ].flatten(
+ # # start_dim=self.channel_dim, end_dim=self.channel_dim + 1
+ # # ), # DI
+ # # inputs[2].expand(
+ # # *tuple([inputs[0].shape[0]] + len(inputs[2].shape) * [-1])
+ # # ), # constants
+ # # inputs[3].permute(0, 2, 1, 3, 4), # coupled inputs
+ # # ]
+ # # res = th.cat(result, dim=self.channel_dim)
+
+ # # else:
+ # # if self.n_constants == 0:
+ # # result = [ # This logic changes for no insolation layer for the time being
+ # # inputs[0].flatten(
+ # # start_dim=self.channel_dim, end_dim=self.channel_dim + 1
+ # # ),
+ # # # inputs[0].flatten(
+ # # # start_dim=self.channel_dim, end_dim=self.channel_dim + 1
+ # # # ),
+ # # # inputs[1][
+ # # # :,
+ # # # :,
+ # # # slice(
+ # # # step * self.input_time_size, (step + 1) * self.input_time_size
+ # # # ),
+ # # # ...,
+ # # # ].flatten(
+ # # # start_dim=self.channel_dim, end_dim=self.channel_dim + 1
+ # # # ), # DI
+ # # ]
+ # # res = th.cat(result, dim=self.channel_dim)
+
+ # # # fold faces into batch dim
+ # # res = self.fold(res)
+
+ # # return res
+
+ # # if self.decoder_input_channels == 0:
+ # # result = [
+ # # inputs[0].flatten(
+ # # start_dim=self.channel_dim, end_dim=self.channel_dim + 1
+ # # ),
+ # # inputs[1].expand(
+ # # *tuple([inputs[0].shape[0]] + len(inputs[1].shape) * [-1])
+ # # ), # inputs
+ # # ]
+ # # print(f"result 1 is {result[0].shape}")
+ # # print(f"result 2 is {result[1].shape}")
+
+ # # res = th.cat(result, dim=self.channel_dim)
+
+ # # # fold faces into batch dim
+ # # res = self.fold(res)
+
+ # # return res
+
+ # # result = [
+ # # inputs[0].flatten(
+ # # start_dim=self.channel_dim, end_dim=self.channel_dim + 1
+ # # ),
+ # # inputs[1][
+ # # :,
+ # # :,
+ # # slice(step * self.input_time_size, (step + 1) * self.input_time_size),
+ # # ...,
+ # # ].flatten(
+ # # start_dim=self.channel_dim, end_dim=self.channel_dim + 1
+ # # ), # DI
+ # # inputs[2].expand(
+ # # *tuple([inputs[0].shape[0]] + len(inputs[2].shape) * [-1])
+ # # ), # constants
+ # # ]
+ # # res = th.cat(result, dim=self.channel_dim)
+
+ # # fold faces into batch dim
+ # # res = self.fold(res)
+ # # res = self.fold(inputs)
+ # # return res
+
+ # def _reshape_outputs(self, outputs: th.Tensor) -> th.Tensor:
+ # """Returns a multiple tensors to from the model decoder.
+ # Splits the time/channel dimensions.
+
+ # Args:
+ # inputs: list of expected input tensors (inputs, decoder_inputs, constants)
+ # step: step number in the sequence of integration_steps
+
+ # Returns:
+ # reshaped Tensor in expected shape for model outputs
+ # """
+ # # unfold:
+ # outputs = self.unfold(outputs)
+
+ # return outputs
+
+ def _initialize_hidden(
+ self, inputs: Sequence, outputs: Sequence, step: int
+ ) -> None:
+ """Initialize the hidden layers
+
+ Args:
+            inputs: Inputs to use to initialize the hidden layers
+            outputs: Outputs to use to initialize the hidden layers
+ step: Current step number of the initialization
+ """
+ self.reset()
+ for prestep in range(self.presteps):
+ if step < self.presteps:
+ s = step + prestep
+ if len(self.couplings) > 0:
+ input_tensor = self.fold(
+ inputs=[
+ inputs[0][
+ :,
+ :,
+ s * self.input_time_size : (s + 1)
+ * self.input_time_size,
+ ]
+ ]
+ + list(inputs[1:3])
+ + [inputs[3][prestep]],
+ step=step + prestep,
+ )
+ else:
+ input_tensor = self.fold(
+ inputs=[
+ inputs[0][
+ :,
+ :,
+ s * self.input_time_size : (s + 1)
+ * self.input_time_size,
+ ]
+ ]
+ + list(inputs[1:]),
+ step=step + prestep,
+ )
+ else:
+ s = step - self.presteps + prestep
+ if len(self.couplings) > 0:
+ input_tensor = self.fold(
+ inputs=[outputs[s - 1]]
+ + list(inputs[1:3])
+ + [inputs[3][step - (prestep - self.presteps)]],
+ step=s + 1,
+ )
+ else:
+ input_tensor = self.fold(
+ inputs=[outputs[s - 1]] + list(inputs[1:]), step=s + 1
+ )
+ self.decoder(self.encoder(input_tensor))
+
+ def forward(self, inputs: th.Tensor, output_only_last=False) -> th.Tensor:
+ """
+        Forward pass of the HEALPixRecUNet
+
+ Args:
+ inputs: Inputs to the model, which is currently in the form [B, F, C, H, W].
+ (We assume that constants have been preprocessed to persist across batch)
+ (We also assume for now that the time is implied to be 1)
+
+ Originally, this was expected to be a sequence of the form [prognostics|TISR|constants]:
+ [B, F, T, C, H, W] the format for prognostics and TISR
+ [F, C, H, W] the format for constants
+
+ output_only_last: If only the last dimension of the outputs should
+ be returned
+
+ Returns:
+ Predicted outputs
+ """
+ self.reset() # will reset every step for now
+ # We need to make sure that the new input is the correct size, since we no longer have the ability
+ # to differentiate between the inputs, decoder inputs, and constants
+ if self._compute_input_channels() != inputs.shape[2]:
+ raise ValueError(
+ f"Expected input should have channels {self._compute_input_channels()},"
+ f" got {inputs.shape[2]}."
+ )
+
+        # The input handling is simplified now that prognostic variables arrive as channels in a single input tensor.
+        # We assume that any batch-wise expansion of the inputs has already been done.
+
+ # (Re-)initialize recurrent hidden states
+ # if (step * (self.delta_t * self.input_time_size)) % self.reset_cycle == 0:
+ # self._initialize_hidden(inputs=inputs, outputs=outputs, step=step)
+ # Skipping this for now. We assume a single input time step, and resetting every step.
+ # s = self.presteps
+        # Padding happens in HEALPixPadding, which will unfold this tensor to handle it.
+        input_tensor = self.fold(inputs)
+
+ encodings = self.encoder(input_tensor)
+ decodings = self.decoder(encodings)
+
+ # Residual prediction
+ n_prognostic_channels = self.prognostic_variables * self.input_time_size
+ prognostic_outputs = (
+ input_tensor[:, :n_prognostic_channels]
+ + decodings[:, :n_prognostic_channels]
+ )
+
+ outputs_only = decodings[:, n_prognostic_channels:]
+
+ reshaped = th.cat(
+ [self.unfold(prognostic_outputs), self.unfold(outputs_only)],
+ dim=self.channel_dim,
+ )
+
+ return reshaped
+
+ def reset(self):
+ """Resets the state of the network"""
+ self.encoder.reset()
+ self.decoder.reset()
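+
+
+# Rough shape flow through HEALPixRecUNet.forward (a descriptive note on the code
+# above, not executable code):
+#   inputs  [B, 12, C_total, H, W], where C_total == self._compute_input_channels()
+#   fold    -> [B*12, C_total, H, W] -> encoder -> decoder
+#   residual: the first prognostic_variables * input_time_size channels of the
+#             decoder output are added to the corresponding input channels; the
+#             remaining (diagnostic) channels are passed through unchanged
+#   unfold + concat along the channel dim -> [B, 12, C_out, H, W]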
diff --git a/fme/fme/ace/models/makani/sfnonet.py b/fme/fme/ace/models/makani/sfnonet.py
index 4e33ec9..18367f6 100644
--- a/fme/fme/ace/models/makani/sfnonet.py
+++ b/fme/fme/ace/models/makani/sfnonet.py
@@ -17,6 +17,7 @@
import math
from functools import partial
+from typing import Tuple
# for annotation of models
import torch
@@ -249,8 +250,8 @@ def __init__(
sht_grid_type="legendre-gauss",
filter_type="linear",
operator_type="dhconv",
- inp_shape=(721, 1440),
- out_shape=(721, 1440),
+ inp_shape: Tuple[int, int] = (721, 1440),
+ out_shape: Tuple[int, int] = (721, 1440),
scale_factor=8,
inp_chans=2,
out_chans=2,
@@ -424,9 +425,7 @@ def __init__(
torch.zeros(1, embed_dim, self.inp_shape_loc[0], self.inp_shape_loc[1])
)
# information about how tensors are shared / sharded across ranks
- self.pos_embed.is_shared_mp = (
- []
- ) # no reduction required since pos_embed is already serial
+ self.pos_embed.is_shared_mp = [] # no reduction required since pos_embed is already serial
self.pos_embed.sharded_dims_mp = [None, None, "h", "w"]
self.pos_embed.type = "direct"
with torch.no_grad():
@@ -500,10 +499,10 @@ def _init_spectral_transforms(
ifft_handle = InverseRealFFT2
self.trans_down = fft_handle(
- *self.inp_shape, lmax=modes_lat, mmax=modes_lon
+ self.inp_shape[0], self.inp_shape[1], lmax=modes_lat, mmax=modes_lon
).float()
self.itrans_up = ifft_handle(
- *self.out_shape, lmax=modes_lat, mmax=modes_lon
+ self.out_shape[0], self.out_shape[1], lmax=modes_lat, mmax=modes_lon
).float()
self.trans = fft_handle(
self.h, self.w, lmax=modes_lat, mmax=modes_lon
diff --git a/fme/fme/ace/models/modulus/layers.py b/fme/fme/ace/models/modulus/layers.py
index 22eddfe..81ec16b 100644
--- a/fme/fme/ace/models/modulus/layers.py
+++ b/fme/fme/ace/models/modulus/layers.py
@@ -22,10 +22,9 @@
import torch.nn.functional as F
from torch.cuda import amp
from torch.utils.checkpoint import checkpoint
-from torch_harmonics import *
-from .activations import *
-from .contractions import *
+from .activations import ComplexReLU
+from .contractions import compl_mul2d_fwd, compl_muladd2d_fwd
@torch.jit.script
@@ -199,185 +198,6 @@ def forward(self, x): # pragma: no cover
return out
-class SpectralConv2d(nn.Module):
- """
- Spectral Convolution as utilized in
- """
-
- def __init__(
- self,
- forward_transform,
- inverse_transform,
- hidden_size,
- sparsity_threshold=0.0,
- hard_thresholding_fraction=1,
- use_complex_kernels=False,
- compression=None,
- rank=0,
- bias=False,
- ): # pragma: no cover
- super(SpectralConv2d, self).__init__()
-
- self.hidden_size = hidden_size
- self.sparsity_threshold = sparsity_threshold
- self.hard_thresholding_fraction = hard_thresholding_fraction
- self.scale = 1 / hidden_size**2
- self.contract_handle = (
- compl_contract2d_fwd_c if use_complex_kernels else compl_contract2d_fwd
- )
-
- self.forward_transform = forward_transform
- self.inverse_transform = inverse_transform
-
- self.output_dims = (self.inverse_transform.nlat, self.inverse_transform.nlon)
- modes_lat = self.inverse_transform.lmax
- modes_lon = self.inverse_transform.mmax
- self.modes_lat = int(modes_lat * self.hard_thresholding_fraction)
- self.modes_lon = int(modes_lon * self.hard_thresholding_fraction)
-
- # new simple linear layer
- self.w = nn.Parameter(
- self.scale
- * torch.randn(
- self.hidden_size, self.hidden_size, self.modes_lat, self.modes_lon, 2
- )
- )
- # optional bias
- if bias:
- self.b = nn.Parameter(
- self.scale * torch.randn(1, self.hidden_size, *self.output_dims)
- )
-
- def forward(self, x): # pragma: no cover
- dtype = x.dtype
- # x = x.float()
- B, C, H, W = x.shape
-
- with amp.autocast(enabled=False):
- x = x.to(torch.float32)
- x = self.forward_transform(x)
- x = torch.view_as_real(x)
- x = x.to(dtype)
-
- # do spectral conv
- modes = torch.zeros(x.shape, device=x.device)
-
- # modes[:, :, :self.modes_lat, :self.modes_lon, :] = self.contract_handle(x[:, :, :self.modes_lat, :self.modes_lon, :], self.wh)
- # modes[:, :, -self.modes_lat:, :self.modes_lon, :] = self.contract_handle(x[:, :, -self.modes_lat:, :self.modes_lon, :], self.wl)
- modes = self.contract_handle(x, self.w)
-
- # finalize
- x = F.softshrink(modes, lambd=self.sparsity_threshold)
- x = torch.view_as_complex(x)
-
- with amp.autocast(enabled=False):
- x = x.to(torch.float32)
- x = torch.view_as_complex(x)
- x = x.contiguous()
- x = self.inverse_transform(x)
- x = x.to(dtype)
-
- if hasattr(self, "b"):
- x = x + self.b
-
- return x
-
-
-class SpectralConvS2(nn.Module):
- """
- Spectral Convolution as utilized in
- """
-
- def __init__(
- self,
- forward_transform,
- inverse_transform,
- hidden_size,
- sparsity_threshold=0.0,
- use_complex_kernels=False,
- compression=None,
- rank=128,
- bias=False,
- ): # pragma: no cover
- super(SpectralConvS2, self).__init__()
-
- self.hidden_size = hidden_size
- self.sparsity_threshold = sparsity_threshold
- self.scale = 0.02
-
- self.forward_transform = forward_transform
- self.inverse_transform = inverse_transform
-
- self.modes_lat = self.forward_transform.lmax
- self.modes_lon = self.forward_transform.mmax
-
- assert self.inverse_transform.lmax == self.modes_lat
- assert self.inverse_transform.mmax == self.modes_lon
-
- # remember the lower triangular indices
- ii, jj = torch.tril_indices(self.modes_lat, self.modes_lon)
- self.register_buffer("ii", ii)
- self.register_buffer("jj", jj)
-
- if compression == "tt":
- self.rank = rank
- # tensortrain coefficients
- g1 = nn.Parameter(self.scale * torch.randn(self.hidden_size, self.rank, 2))
- g2 = nn.Parameter(
- self.scale * torch.randn(self.rank, self.hidden_size, self.rank, 2)
- )
- g3 = nn.Parameter(self.scale * torch.randn(self.rank, len(ii), 2))
- self.w = nn.ParameterList([g1, g2, g3])
-
- self.contract_handle = (
- contract_tt # if use_complex_kernels else raise(NotImplementedError)
- )
- else:
- self.w = nn.Parameter(
- self.scale * torch.randn(self.hidden_size, self.hidden_size, len(ii), 2)
- )
- self.contract_handle = (
- compl_contract_fwd_c if use_complex_kernels else compl_contract_fwd
- )
-
- if bias:
- self.b = nn.Parameter(
- self.scale * torch.randn(1, self.hidden_size, *self.output_dims)
- )
-
- def forward(self, x): # pragma: no cover
- dtype = x.dtype
- # x = x.float()
- B, C, H, W = x.shape
-
- with amp.autocast(enabled=False):
- x = x.to(torch.float32)
- x = x.contiguous()
- x = self.forward_transform(x)
- x = torch.view_as_real(x)
- x = x.to(dtype)
-
- # do spectral conv
- modes = torch.zeros(x.shape, device=x.device)
- modes[:, :, self.ii, self.jj, :] = self.contract_handle(
- x[:, :, self.ii, self.jj, :], self.w
- )
-
- # finalize
- x = F.softshrink(modes, lambd=self.sparsity_threshold)
-
- with amp.autocast(enabled=False):
- x = x.to(torch.float32)
- x = torch.view_as_complex(x)
- x = self.inverse_transform(x)
- x = x.to(dtype)
-
- if hasattr(self, "b"):
- x = x + self.b
-
- return x
-
-
class SpectralAttention2d(nn.Module):
"""
2d Spectral Attention layer
@@ -404,10 +224,10 @@ def __init__(
self.hidden_size = int(hidden_size_factor * self.embed_dim)
self.scale = 0.02
self.spectral_layers = spectral_layers
- self.mul_add_handle = (
- compl_muladd2d_fwd_c if use_complex_kernels else compl_muladd2d_fwd
- )
- self.mul_handle = compl_mul2d_fwd_c if use_complex_kernels else compl_mul2d_fwd
+ if use_complex_kernels:
+ raise NotImplementedError("complex kernels not supported")
+ self.mul_add_handle = compl_muladd2d_fwd
+ self.mul_handle = compl_mul2d_fwd
self.modes_lat = forward_transform.lmax
self.modes_lon = forward_transform.mmax
@@ -512,12 +332,12 @@ def __init__(
self.sparsity_threshold = sparsity_threshold
self.hidden_size = int(hidden_size_factor * self.embed_dim)
self.scale = 0.02
+ if use_complex_kernels:
+ raise NotImplementedError("complex kernels not supported")
# self.mul_add_handle = compl_muladd1d_fwd_c if use_complex_kernels else compl_muladd1d_fwd
- self.mul_add_handle = (
- compl_muladd2d_fwd_c if use_complex_kernels else compl_muladd2d_fwd
- )
+ self.mul_add_handle = compl_muladd2d_fwd
# self.mul_handle = compl_mul1d_fwd_c if use_complex_kernels else compl_mul1d_fwd
- self.mul_handle = compl_mul2d_fwd_c if use_complex_kernels else compl_mul2d_fwd
+ self.mul_handle = compl_mul2d_fwd
self.spectral_layers = spectral_layers
self.modes_lat = forward_transform.lmax
diff --git a/fme/fme/ace/models/modulus/sfnonet.py b/fme/fme/ace/models/modulus/sfnonet.py
index 0c0434b..1f8806c 100644
--- a/fme/fme/ace/models/modulus/sfnonet.py
+++ b/fme/fme/ace/models/modulus/sfnonet.py
@@ -232,9 +232,9 @@ def forward(self, x):
x = self.act_layer(x)
x_norm = torch.zeros_like(x)
- x_norm[
- ..., : self.output_shape_loc[0], : self.output_shape_loc[1]
- ] = self.norm1(x[..., : self.output_shape_loc[0], : self.output_shape_loc[1]])
+ x_norm[..., : self.output_shape_loc[0], : self.output_shape_loc[1]] = (
+ self.norm1(x[..., : self.output_shape_loc[0], : self.output_shape_loc[1]])
+ )
x = x_norm
if hasattr(self, "mlp"):
diff --git a/fme/fme/ace/registry/__init__.py b/fme/fme/ace/registry/__init__.py
index 12c781f..31b7898 100644
--- a/fme/fme/ace/registry/__init__.py
+++ b/fme/fme/ace/registry/__init__.py
@@ -1,6 +1,6 @@
# import modules so they are registered
from . import prebuilt as _prebuilt
from . import sfno as _sfno
-from .registry import ModuleSelector, get_from_registry, register
+from .registry import ModuleSelector
del _prebuilt, _sfno
diff --git a/fme/fme/ace/registry/hpx.py b/fme/fme/ace/registry/hpx.py
new file mode 100644
index 0000000..e90f42c
--- /dev/null
+++ b/fme/fme/ace/registry/hpx.py
@@ -0,0 +1,78 @@
+import dataclasses
+from typing import Tuple
+
+import torch.nn as nn
+
+from fme.ace.models.healpix.healpix_decoder import UNetDecoderConfig
+from fme.ace.models.healpix.healpix_encoder import UNetEncoderConfig
+from fme.ace.models.healpix.healpix_recunet import HEALPixRecUNet
+from fme.ace.registry.registry import ModuleConfig, ModuleSelector
+
+
+@ModuleSelector.register("HEALPixRecUNet")
+@dataclasses.dataclass
+class HEALPixRecUNetBuilder(ModuleConfig):
+ """
+ Configuration for the HEALPixRecUNet architecture used in DLWP.
+
+    Parameters:
+        encoder: Configuration for the U-net encoder.
+        decoder: Configuration for the U-net decoder.
+        presteps: Number of pre-steps, by default 1.
+        input_time_size: Input time dimension, by default 0.
+        output_time_size: Output time dimension, by default 0.
+        delta_time: Delta time interval, by default "6h".
+        reset_cycle: Reset cycle interval, by default "24h".
+        n_constants: Number of constant input channels, by default 2.
+        decoder_input_channels: Number of input channels for the decoder, by default 1.
+        prognostic_variables: Number of prognostic variables, by default 7.
+        enable_nhwc: Flag to enable NHWC data format, by default False.
+        enable_healpixpad: Flag to enable HEALPix padding, by default False.
+ """
+
+ encoder: UNetEncoderConfig
+ decoder: UNetDecoderConfig
+ presteps: int = 1
+ input_time_size: int = 0
+ output_time_size: int = 0
+ delta_time: str = "6h"
+ reset_cycle: str = "24h"
+ n_constants: int = 2
+ decoder_input_channels: int = 1
+ prognostic_variables: int = 7
+ enable_nhwc: bool = False
+ enable_healpixpad: bool = False
+
+ def build(
+ self,
+ n_in_channels: int,
+ n_out_channels: int,
+ img_shape: Tuple[int, int],
+ ) -> nn.Module:
+ """
+ Builds the HEALPixRecUNet model.
+
+ Args:
+ n_in_channels: Number of input channels.
+ n_out_channels: Number of output channels.
+ img_shape: Shape of the input image.
+
+ Returns:
+ HEALPixRecUNet model.
+ """
+ # Construct the HEALPixRecUNet module here using the parameters
+ return HEALPixRecUNet(
+ encoder=self.encoder,
+ decoder=self.decoder,
+ input_channels=n_in_channels,
+ output_channels=n_out_channels,
+ prognostic_variables=self.prognostic_variables,
+ n_constants=self.n_constants,
+ decoder_input_channels=self.decoder_input_channels,
+ input_time_size=self.input_time_size,
+ output_time_size=self.output_time_size,
+ delta_time=self.delta_time,
+ reset_cycle=self.reset_cycle,
+ presteps=self.presteps,
+ enable_nhwc=self.enable_nhwc,
+ enable_healpixpad=self.enable_healpixpad,
+ )
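+
+
+# A minimal builder sketch (illustrative only; encoder_cfg and decoder_cfg stand in
+# for valid UNetEncoderConfig / UNetDecoderConfig instances such as those built in
+# fme/fme/ace/registry/test_hpx.py):
+#
+#     builder = HEALPixRecUNetBuilder(
+#         encoder=encoder_cfg,
+#         decoder=decoder_cfg,
+#         input_time_size=2,
+#         output_time_size=4,
+#     )
+#     module = builder.build(n_in_channels=7, n_out_channels=7, img_shape=(8, 16))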
diff --git a/fme/fme/ace/registry/prebuilt.py b/fme/fme/ace/registry/prebuilt.py
index f8bebdc..880df5e 100644
--- a/fme/fme/ace/registry/prebuilt.py
+++ b/fme/fme/ace/registry/prebuilt.py
@@ -3,10 +3,10 @@
from torch import nn
-from fme.ace.registry.registry import ModuleConfig, register
+from fme.ace.registry.registry import ModuleConfig, ModuleSelector
-@register("prebuilt")
+@ModuleSelector.register("prebuilt")
@dataclasses.dataclass
class PreBuiltBuilder(ModuleConfig):
"""
diff --git a/fme/fme/ace/registry/registry.py b/fme/fme/ace/registry/registry.py
index 2486c8b..3fa9bd1 100644
--- a/fme/fme/ace/registry/registry.py
+++ b/fme/fme/ace/registry/registry.py
@@ -1,6 +1,4 @@
-from fme.core.registry import ( # noqa: F401
+from fme.core.registry.module import ( # noqa: F401
ModuleConfig,
ModuleSelector,
- get_from_registry,
- register,
)
diff --git a/fme/fme/ace/registry/sfno.py b/fme/fme/ace/registry/sfno.py
index 0031ea7..f915055 100644
--- a/fme/fme/ace/registry/sfno.py
+++ b/fme/fme/ace/registry/sfno.py
@@ -5,12 +5,12 @@
SphericalFourierNeuralOperatorNet as MakaniSFNO,
)
from fme.ace.models.modulus.sfnonet import SphericalFourierNeuralOperatorNet
-from fme.ace.registry.registry import ModuleConfig, register
+from fme.ace.registry.registry import ModuleConfig, ModuleSelector
# this is based on the call signature of SphericalFourierNeuralOperatorNet at
# https://github.com/NVIDIA/modulus/blob/b8e27c5c4ebc409e53adaba9832138743ede2785/modulus/models/sfno/sfnonet.py#L292 # noqa: E501
-@register("SphericalFourierNeuralOperatorNet")
+@ModuleSelector.register("SphericalFourierNeuralOperatorNet")
@dataclasses.dataclass
class SphericalFourierNeuralOperatorBuilder(ModuleConfig):
"""
@@ -55,7 +55,7 @@ def build(
return sfno_net
-@register("SFNO-v0.1.0")
+@ModuleSelector.register("SFNO-v0.1.0")
@dataclasses.dataclass
class SFNO_V0_1_0(ModuleConfig):
"""
diff --git a/fme/fme/ace/registry/test_hpx.py b/fme/fme/ace/registry/test_hpx.py
new file mode 100644
index 0000000..9a77522
--- /dev/null
+++ b/fme/fme/ace/registry/test_hpx.py
@@ -0,0 +1,830 @@
+import dataclasses
+import datetime
+import logging
+from typing import Tuple, Union
+
+import numpy as np
+import pytest
+import torch as th
+
+from fme.ace.models.healpix.healpix_activations import (
+ CappedGELUConfig,
+ DownsamplingBlockConfig,
+)
+from fme.ace.models.healpix.healpix_blocks import ConvBlockConfig, RecurrentBlockConfig
+from fme.ace.models.healpix.healpix_decoder import UNetDecoder
+from fme.ace.models.healpix.healpix_encoder import UNetEncoder
+from fme.ace.models.healpix.healpix_recunet import HEALPixRecUNet
+from fme.ace.registry.hpx import UNetDecoderConfig, UNetEncoderConfig
+from fme.ace.stepper import SingleModuleStepperConfig
+from fme.core.coordinates import HybridSigmaPressureCoordinate
+from fme.core.device import get_device
+from fme.core.gridded_ops import LatLonOperations
+from fme.core.normalizer import NormalizationConfig
+
+TIMESTEP = datetime.timedelta(hours=6)
+logger = logging.getLogger(__name__)
+
+
+def fix_random_seeds(seed=0):
+ """Fix random seeds for reproducibility"""
+ np.random.seed(seed)
+ th.manual_seed(seed)
+ th.cuda.manual_seed(seed)
+
+
+def conv_next_block_config(in_channels=3, out_channels=1):
+ activation_block_config = CappedGELUConfig(cap_value=10)
+ conv_next_block_config = ConvBlockConfig(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ activation=activation_block_config,
+ kernel_size=3,
+ dilation=1,
+ upscale_factor=4,
+ block_type="ConvNeXtBlock",
+ )
+ return conv_next_block_config
+
+
+def down_sampling_block_config():
+ return DownsamplingBlockConfig(pooling=2, block_type="AvgPool")
+
+
+def encoder_config(
+ conv_next_block_config, down_sampling_block_config, n_channels=[136, 68, 34]
+):
+ return UNetEncoderConfig(
+ conv_block=conv_next_block_config,
+ down_sampling_block=down_sampling_block_config,
+ n_channels=n_channels,
+ dilations=[1, 2, 4],
+ )
+
+
+def up_sampling_block_config(in_channels=3, out_channels=1):
+ activation_block_config = CappedGELUConfig(cap_value=10)
+ transposed_conv_upsample_block_config = ConvBlockConfig(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ activation=activation_block_config,
+ upsampling=2,
+ block_type="TransposedConvUpsample",
+ )
+ return transposed_conv_upsample_block_config
+
+
+def output_layer_config(in_channels=3, out_channels=2):
+ conv_block_config = ConvBlockConfig(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=1,
+ dilation=1,
+ n_layers=1,
+ block_type="BasicConvBlock",
+ )
+ return conv_block_config
+
+
+def recurrent_block_config(in_channels=3):
+ recurrent_block_config = RecurrentBlockConfig(
+ in_channels=in_channels,
+ kernel_size=1,
+ block_type="ConvGRUBlock",
+ )
+ return recurrent_block_config
+
+
+def decoder_config(
+ conv_next_block_config,
+ up_sampling_block_config,
+ output_layer_config,
+ recurrent_block_config,
+ n_channels=[34, 68, 136],
+):
+ decoder_config = UNetDecoderConfig(
+ conv_block=conv_next_block_config,
+ up_sampling_block=up_sampling_block_config,
+ recurrent_block=recurrent_block_config,
+ output_layer=output_layer_config,
+ n_channels=n_channels,
+ dilations=[4, 2, 1],
+ )
+ return decoder_config
+
+
+def _test_data():
+ # create dummy data
+ def generate_test_data(batch_size=8, time_dim=1, channels=7, img_size=16):
+ device = get_device()
+ test_data = th.randn(batch_size, 12, time_dim * channels, img_size, img_size)
+ return test_data.to(device)
+
+ return generate_test_data
+
+
+def constant_data():
+ # create dummy data
+ def generate_constant_data(channels=2, img_size=16):
+ device = get_device()
+ constants = th.randn(12, channels, img_size, img_size)
+
+ return constants.to(device)
+
+ return generate_constant_data
+
+
+def insolation_data():
+ # create dummy data
+ def generate_insolation_data(batch_size=8, time_dim=1, img_size=16):
+ device = get_device()
+ insolation = th.randn(batch_size, 12, time_dim, img_size, img_size)
+
+ return insolation.to(device)
+
+ return generate_insolation_data
+
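+# Shape conventions produced by the helpers above (for reference):
+#   _test_data()(...)      -> [batch, 12, time * channels, H, W]
+#   constant_data()(...)   -> [12, channels, H, W]
+#   insolation_data()(...) -> [batch, 12, time, H, W]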
+
+@pytest.mark.parametrize(
+ "shape",
+ [
+ pytest.param((8, 16)),
+ ],
+)
+def test_hpx_init(shape):
+ in_channels = 7
+ out_channels = 7
+ prognostic_variables = min(in_channels, out_channels)
+ n_constants = 1
+ decoder_input_channels = 1
+ input_time_size = 2
+ output_time_size = 4
+ device = get_device()
+
+ conv_next_block = conv_next_block_config()
+ down_sampling_block = down_sampling_block_config()
+ recurrent_block = recurrent_block_config()
+ encoder = encoder_config(conv_next_block, down_sampling_block)
+ up_sampling_block = up_sampling_block_config()
+ output_layer = output_layer_config()
+ decoder = decoder_config(
+ conv_next_block, up_sampling_block, output_layer, recurrent_block
+ )
+
+ hpx_config_data = {
+ "type": "HEALPixRecUNet",
+ "config": {
+ "encoder": dataclasses.asdict(encoder),
+ "decoder": dataclasses.asdict(decoder),
+ "prognostic_variables": prognostic_variables,
+ "n_constants": n_constants,
+ "decoder_input_channels": decoder_input_channels,
+ "input_time_size": input_time_size,
+ "output_time_size": output_time_size,
+ },
+ }
+
+ stepper_config_data = {
+ "builder": hpx_config_data,
+ "in_names": ["x"],
+ "out_names": ["x"],
+ "normalization": dataclasses.asdict(
+ NormalizationConfig(
+ means={"x": float(np.random.randn(1).item())},
+ stds={"x": float(np.random.randn(1).item())},
+ )
+ ),
+ }
+ area = th.ones((1, 16, 32)).to(device)
+ vertical_coordinate = HybridSigmaPressureCoordinate(
+ ak=th.arange(7), bk=th.arange(7)
+ ).to(device)
+ stepper_config = SingleModuleStepperConfig.from_state(stepper_config_data)
+ stepper = stepper_config.get_stepper(
+ img_shape=shape,
+ gridded_operations=LatLonOperations(area),
+ vertical_coordinate=vertical_coordinate,
+ timestep=TIMESTEP,
+ )
+ assert type(stepper.module.module) is HEALPixRecUNet
+
+
+@pytest.mark.parametrize(
+ "in_channels, out_channels, n_constants, decoder_input_channels, input_time_size, \
+ output_time_size, couplings, expected_exception, expected_message",
+ [
+ (7, 7, 1, 1, 2, 4, None, None, None), # Valid case
+ (
+ 7,
+ 7,
+ 1,
+ 1,
+ 2,
+ 3,
+ None,
+ ValueError,
+ "'output_time_size' must be a multiple of 'input_time_size'",
+ ), # Bad input and output time dims
+ (
+ 7,
+ 7,
+ 0,
+ 2,
+ 2,
+ 3,
+ ["t2m", "v10m"],
+ NotImplementedError,
+ "support for coupled models with no constant field",
+ ), # Couplings with no constants
+ (
+ 7,
+ 7,
+ 2,
+ 0,
+ 2,
+ 3,
+ ["t2m", "v10m"],
+ NotImplementedError,
+ "support for coupled models with no decoder",
+ ), # Couplings with no decoder input channels
+ (
+ 7,
+ 7,
+ 0,
+ 0,
+ 2,
+ 3,
+ None,
+ ValueError,
+ "'output_time_size' must be a multiple of 'input_time_size'",
+ ), # No constant fields and no decoder
+ ],
+)
+def test_HEALPixRecUNet_initialize(
+ in_channels,
+ out_channels,
+ n_constants,
+ decoder_input_channels,
+ input_time_size,
+ output_time_size,
+ couplings,
+ expected_exception,
+ expected_message,
+):
+ prognostic_variables = min(out_channels, in_channels)
+ conv_next_block = conv_next_block_config()
+ up_sampling_block = up_sampling_block_config()
+ output_layer = output_layer_config()
+ recurrent_block = recurrent_block_config()
+ encoder = encoder_config(conv_next_block, down_sampling_block_config())
+ decoder = decoder_config(
+ conv_next_block, up_sampling_block, output_layer, recurrent_block
+ )
+ device = get_device()
+
+ if expected_exception:
+ with pytest.raises(expected_exception, match=expected_message):
+ model = HEALPixRecUNet(
+ encoder=encoder,
+ decoder=decoder,
+ input_channels=in_channels,
+ output_channels=out_channels,
+ prognostic_variables=prognostic_variables,
+ n_constants=n_constants,
+ decoder_input_channels=decoder_input_channels,
+ input_time_size=input_time_size,
+ output_time_size=output_time_size,
+ couplings=couplings,
+ ).to(device)
+ else:
+ model = HEALPixRecUNet(
+ encoder=encoder,
+ decoder=decoder,
+ input_channels=in_channels,
+ output_channels=out_channels,
+ prognostic_variables=prognostic_variables,
+ n_constants=n_constants,
+ decoder_input_channels=decoder_input_channels,
+ input_time_size=input_time_size,
+ output_time_size=output_time_size,
+ couplings=couplings,
+ ).to(device)
+ assert isinstance(model, HEALPixRecUNet)
+
+
+def test_HEALPixRecUNet_integration_steps():
+ in_channels = 2
+ out_channels = 2
+ prognostic_variables = min(out_channels, in_channels)
+ n_constants = 1
+ decoder_input_channels = 0
+ input_time_size = 2
+ output_time_size = 4
+ device = get_device()
+
+ conv_next_block = conv_next_block_config()
+ up_sampling_block = up_sampling_block_config()
+ output_layer = output_layer_config()
+ recurrent_block = recurrent_block_config()
+ encoder = encoder_config(conv_next_block, down_sampling_block_config())
+ decoder = decoder_config(
+ conv_next_block, up_sampling_block, output_layer, recurrent_block
+ )
+
+ model = HEALPixRecUNet(
+ encoder=encoder,
+ decoder=decoder,
+ input_channels=in_channels,
+ output_channels=out_channels,
+ prognostic_variables=prognostic_variables,
+ n_constants=n_constants,
+ decoder_input_channels=decoder_input_channels,
+ input_time_size=input_time_size,
+ output_time_size=output_time_size,
+ ).to(device)
+
+ assert model.integration_steps == output_time_size // input_time_size
+
+
+def test_HEALPixRecUNet_reset(very_fast_only: bool):
+ if very_fast_only:
+ pytest.skip("Skipping non-fast tests")
+ # create a smaller version of the dlwp healpix model
+ in_channels = 3
+ out_channels = 3
+ prognostic_variables = min(out_channels, in_channels)
+ n_constants = 2
+ decoder_input_channels = 1
+ input_time_size = 2
+ output_time_size = 4
+ size = 16
+ device = get_device()
+
+ conv_next_block = conv_next_block_config()
+ up_sampling_block = up_sampling_block_config()
+ output_layer = output_layer_config()
+ recurrent_block = recurrent_block_config()
+ encoder = encoder_config(conv_next_block, down_sampling_block_config())
+ decoder = decoder_config(
+ conv_next_block, up_sampling_block, output_layer, recurrent_block
+ )
+
+ fix_random_seeds(seed=42)
+ x = _test_data()(time_dim=input_time_size, channels=in_channels, img_size=size)
+ decoder_inputs = insolation_data()(time_dim=input_time_size, img_size=size)
+ constants = constant_data()(channels=n_constants, img_size=size)
+ batch_size = x.shape[0]
+ constants = constants.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)
+ inputs = th.concat(
+ (x, decoder_inputs, constants), dim=-3
+ ) # [x, decoder_inputs, constants]
+
+ model = HEALPixRecUNet(
+ encoder=encoder,
+ decoder=decoder,
+ input_channels=in_channels,
+ output_channels=out_channels,
+ prognostic_variables=prognostic_variables,
+ n_constants=n_constants,
+ decoder_input_channels=decoder_input_channels,
+ input_time_size=input_time_size,
+ output_time_size=output_time_size,
+ enable_healpixpad=False,
+ delta_time="6h",
+ ).to(device)
+
+ out_var = model(inputs)
+ model.reset()
+
+ assert compare_output(out_var, model(inputs))
+
+
+# Checks that the model can perform a forward pass on various input configurations
+# [full inputs, no decoder inputs, no constant inputs]
+@pytest.mark.parametrize(
+ "inputs_config, in_channels, decoder_input_channels, \
+ out_channels, input_time_size, output_time_size, n_constants, size",
+ [
+ ([0, 1, 2], 3, 1, 3, 2, 4, 2, 16), # full inputs
+ ([0, 2], 3, 0, 3, 2, 4, 2, 16), # no decoder inputs
+ ([0, 1], 3, 1, 3, 2, 4, 0, 16), # no constant inputs
+ ],
+)
+def test_HEALPixRecUNet_forward(
+ inputs_config,
+ in_channels,
+ decoder_input_channels,
+ out_channels,
+ input_time_size,
+ output_time_size,
+ n_constants,
+ size,
+ very_fast_only: bool,
+):
+ if very_fast_only:
+ pytest.skip("Skipping non-fast tests")
+ prognostic_variables = min(out_channels, in_channels)
+ device = get_device()
+ conv_next_block = conv_next_block_config()
+ up_sampling_block = up_sampling_block_config()
+ output_layer = output_layer_config()
+ recurrent_block = recurrent_block_config()
+ encoder = encoder_config(conv_next_block, down_sampling_block_config())
+ decoder = decoder_config(
+ conv_next_block, up_sampling_block, output_layer, recurrent_block
+ )
+
+ fix_random_seeds(seed=42)
+ x = _test_data()(time_dim=input_time_size, channels=in_channels, img_size=size)
+ batch_size = x.shape[0]
+
+ if decoder_input_channels > 0:
+ decoder_inputs = insolation_data()(time_dim=input_time_size, img_size=size)
+ else:
+ decoder_inputs = insolation_data()(time_dim=0, img_size=size)
+ constants = constant_data()(channels=n_constants, img_size=size)
+ constants = constants.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)
+
+ all_inputs = [x, decoder_inputs, constants]
+ inputs = th.concat(all_inputs, dim=-3)
+
+ model = HEALPixRecUNet(
+ encoder=encoder,
+ decoder=decoder,
+ input_channels=in_channels,
+ output_channels=out_channels,
+ prognostic_variables=prognostic_variables,
+ n_constants=n_constants,
+ decoder_input_channels=decoder_input_channels,
+ input_time_size=input_time_size,
+ output_time_size=output_time_size,
+ enable_healpixpad=False,
+ delta_time="6h",
+ ).to(device)
+ model(inputs)
+
+
+# pragma mark - encoder
+
+
+def test_UNetEncoder_initialize():
+ device = get_device()
+ channels = 2
+ n_channels = (16, 32, 64)
+
+ # Dicts for block configs used by encoder
+ conv_block_config = ConvBlockConfig(
+ in_channels=channels,
+ block_type="ConvNeXtBlock",
+ )
+ down_sampling_block_config = DownsamplingBlockConfig(
+ pooling=2, block_type="MaxPool"
+ )
+
+ encoder = UNetEncoder(
+ conv_block=conv_block_config,
+ down_sampling_block=down_sampling_block_config,
+ n_channels=n_channels,
+ input_channels=channels,
+ ).to(device)
+ assert isinstance(encoder, UNetEncoder)
+
+ # with dilations
+ encoder = UNetEncoder(
+ conv_block=conv_block_config,
+ down_sampling_block=down_sampling_block_config,
+ n_channels=n_channels,
+ input_channels=channels,
+ dilations=[1, 1, 1],
+ ).to(device)
+ assert isinstance(encoder, UNetEncoder)
+
+
+def test_UNetEncoder_forward():
+ channels = 2
+ hw_size = 16
+ b_size = 12
+ n_channels = (16, 32, 64)
+ device = get_device()
+
+ # block configs used by encoder
+ conv_block_config = ConvBlockConfig(
+ in_channels=channels,
+ block_type="ConvNeXtBlock",
+ )
+ down_sampling_block_config = DownsamplingBlockConfig(
+ pooling=2, block_type="MaxPool"
+ )
+ encoder = UNetEncoder(
+ conv_block=conv_block_config,
+ down_sampling_block=down_sampling_block_config,
+ n_channels=n_channels,
+ input_channels=channels,
+ ).to(device)
+
+ tensor_size = [b_size, channels, hw_size, hw_size]
+ invar = th.rand(tensor_size).to(device)
+ outvar = encoder(invar)
+
+ # doesn't do anything
+ encoder.reset()
+
+    # outvar is a list of feature tensors, one per encoder level
+ for idx, out_tensor in enumerate(outvar):
+ # verify the channels and h dim are correct
+ assert out_tensor.shape[1] == n_channels[idx]
+        # default behaviour is to halve the h/w size after the first block
+ assert out_tensor.shape[2] == tensor_size[2] // (2**idx)
+
+
+def test_UNetEncoder_reset():
+ channels = 2
+ n_channels = (16, 32, 64)
+ device = get_device()
+
+ # Dicts for block configs used by encoder
+ conv_block_config = ConvBlockConfig(
+ in_channels=channels,
+ block_type="ConvNeXtBlock",
+ )
+ down_sampling_block_config = DownsamplingBlockConfig(
+ pooling=2,
+ block_type="MaxPool",
+ )
+ encoder = UNetEncoder(
+ conv_block=conv_block_config,
+ down_sampling_block=down_sampling_block_config,
+ n_channels=n_channels,
+ input_channels=channels,
+ ).to(device)
+
+ # doesn't do anything
+ encoder.reset()
+ assert isinstance(encoder, UNetEncoder)
+
+
+def test_UNetDecoder_initialization():
+ in_channels = 2
+ out_channels = 1
+ n_channels = (64, 32, 16)
+ device = get_device()
+
+ # Dicts for block configs used by decoder
+ conv_block_config = ConvBlockConfig(
+ in_channels=in_channels, out_channels=out_channels, block_type="ConvNeXtBlock"
+ )
+ up_sampling_block_config = ConvBlockConfig(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ upsampling=2,
+ block_type="TransposedConvUpsample",
+ )
+
+ output_layer_config = ConvBlockConfig(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=1,
+ dilation=1,
+ n_layers=1,
+ block_type="ConvNeXtBlock",
+ )
+
+ recurrent_block_config = RecurrentBlockConfig(
+ in_channels=2,
+ kernel_size=1,
+ block_type="ConvGRUBlock",
+ )
+
+ decoder = UNetDecoder(
+ conv_block=conv_block_config,
+ up_sampling_block=up_sampling_block_config,
+ output_layer=output_layer_config,
+ recurrent_block=recurrent_block_config,
+ n_channels=n_channels,
+ ).to(device)
+
+ assert isinstance(decoder, UNetDecoder)
+
+ # without the recurrent block and with dilations
+ decoder = UNetDecoder(
+ conv_block=conv_block_config,
+ up_sampling_block=up_sampling_block_config,
+ output_layer=output_layer_config,
+ recurrent_block=None,
+ n_channels=n_channels,
+ dilations=[1, 1, 1],
+ ).to(device)
+ assert isinstance(decoder, UNetDecoder)
+
+
+def test_UNetDecoder_forward():
+ in_channels = 2
+ out_channels = 1
+ hw_size = 32
+ b_size = 12
+ n_channels = (64, 32, 16)
+ device = get_device()
+
+ # Dicts for block configs used by decoder
+ conv_block_config = ConvBlockConfig(
+ in_channels=in_channels, out_channels=out_channels, block_type="ConvNeXtBlock"
+ )
+ up_sampling_block_config = ConvBlockConfig(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ upsampling=2,
+ block_type="TransposedConvUpsample",
+ )
+ output_layer_config = ConvBlockConfig(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=1,
+ dilation=1,
+ n_layers=1,
+ block_type="BasicConvBlock",
+ )
+ recurrent_block_config = RecurrentBlockConfig(
+ in_channels=2,
+ kernel_size=1,
+ block_type="ConvGRUBlock",
+ )
+
+ decoder = UNetDecoder(
+ conv_block=conv_block_config,
+ up_sampling_block=up_sampling_block_config,
+ output_layer=output_layer_config,
+ recurrent_block=recurrent_block_config,
+ n_channels=n_channels,
+ ).to(device)
+
+ output_2_size = th.Size([b_size, out_channels, hw_size, hw_size])
+
+ # build the list of tensors for the decoder
+ invars = []
+ # decoder has an algorithm that goes back to front
+ for idx in range(len(n_channels) - 1, -1, -1):
+ tensor_size = [b_size, n_channels[idx], hw_size, hw_size]
+ invars.append(th.rand(tensor_size).to(device))
+ hw_size = hw_size // 2
+
+ outvar = decoder(invars)
+ assert outvar.shape == output_2_size
+
+ # make sure history is taken into account with ConvGRU
+ outvar_hist = decoder(invars)
+ assert not compare_output(outvar, outvar_hist)
+
+ # check with no recurrent
+ decoder = UNetDecoder(
+ conv_block=conv_block_config,
+ up_sampling_block=up_sampling_block_config,
+ output_layer=output_layer_config,
+ recurrent_block=None,
+ n_channels=n_channels,
+ dilations=[1, 1, 1],
+ ).to(device)
+
+ outvar = decoder(invars)
+ assert outvar.shape == output_2_size
+
+
+def test_UNetDecoder_reset():
+ in_channels = 2
+ out_channels = 1
+ hw_size = 32
+ b_size = 12
+ n_channels = (64, 32, 16)
+ device = get_device()
+
+    # Block configs used by the decoder
+ conv_block = ConvBlockConfig(in_channels=in_channels, block_type="ConvNeXtBlock")
+ up_sampling_block = ConvBlockConfig(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ block_type="TransposedConvUpsample",
+ )
+ output_layer = ConvBlockConfig(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=1,
+ dilation=1,
+ n_layers=1,
+ block_type="BasicConvBlock",
+ )
+
+ recurrent_block = RecurrentBlockConfig(
+ in_channels=2, kernel_size=1, block_type="ConvLSTMBlock"
+ )
+
+ decoder = UNetDecoder(
+ conv_block=conv_block,
+ up_sampling_block=up_sampling_block,
+ output_layer=output_layer,
+ recurrent_block=recurrent_block,
+ n_channels=n_channels,
+ ).to(device)
+
+ # build the list of tensors for the decoder
+ invars = []
+    # the decoder consumes the levels back to front, so build the input list in
+    # reverse order of n_channels
+ for idx in range(len(n_channels) - 1, -1, -1):
+ tensor_size = [b_size, n_channels[idx], hw_size, hw_size]
+ invars.append(th.rand(tensor_size).to(device))
+ hw_size = hw_size // 2
+
+ outvar = decoder(invars)
+
+    # make sure history is taken into account by the recurrent (ConvLSTM) block
+ outvar_hist = decoder(invars)
+ assert not compare_output(outvar, outvar_hist)
+
+ # make sure after reset we get the same result
+ decoder.reset()
+ outvar_reset = decoder(invars)
+ assert compare_output(outvar, outvar_reset)
+
+ # test reset without recurrent block
+ decoder = UNetDecoder(
+ conv_block=conv_block,
+ up_sampling_block=up_sampling_block,
+ output_layer=output_layer,
+ recurrent_block=None,
+ n_channels=n_channels,
+ ).to(device)
+
+ outvar = decoder(invars)
+
+ # without the recurrent block should be the same
+ outvar_hist = decoder(invars)
+ assert compare_output(outvar, outvar_hist)
+
+ # make sure after reset we get the same result
+ decoder.reset()
+ outvar_reset = decoder(invars)
+ assert compare_output(outvar, outvar_reset)
+
+
+def compare_output(
+ output_1: Union[th.Tensor, Tuple[th.Tensor, ...]],
+ output_2: Union[th.Tensor, Tuple[th.Tensor, ...]],
+ rtol: float = 1e-5,
+ atol: float = 1e-5,
+) -> bool:
+    """Compares model outputs and returns whether they are the same.
+
+    Args:
+        output_1: First item to compare
+        output_2: Second item to compare
+        rtol: Relative tolerance of error allowed, by default 1e-5
+        atol: Absolute tolerance of error allowed, by default 1e-5
+
+    Returns:
+        True if the outputs match within tolerance, False otherwise.
+    """
+ # Output of tensor
+ if isinstance(output_1, th.Tensor):
+ return th.allclose(output_1, output_2, rtol, atol)
+ # Output of tuple of tensors
+ elif isinstance(output_1, tuple):
+ # Loop through tuple of outputs
+ for i, (out_1, out_2) in enumerate(zip(output_1, output_2)):
+ # If tensor use allclose
+ if isinstance(out_1, th.Tensor):
+ if not th.allclose(out_1, out_2, rtol, atol):
+ logger.warning(f"Failed comparison between outputs {i}")
+ logger.warning(f"Max Difference: {th.amax(th.abs(out_1 - out_2))}")
+ logger.warning(f"Difference: {out_1 - out_2}")
+ return False
+            # Otherwise assume primitive
+ else:
+ if not out_1 == out_2:
+ return False
+ elif isinstance(output_1, (list, tuple)) and isinstance(output_2, (list, tuple)):
+ if len(output_1) != len(output_2):
+ print(
+ f"Length mismatch: output_1 {len(output_1)}, output_2 {len(output_2)}"
+ )
+ return False
+ for a, e in zip(output_1, output_2):
+ if not compare_output(a, e):
+ return False
+ return True
+ elif isinstance(output_1, dict) and isinstance(output_2, dict):
+ if output_1.keys() != output_2.keys():
+ print(
+ f"Keys mismatch: output_1 keys {output_1.keys()}, ",
+ f"output_2 keys {output_2.keys()}",
+ )
+ return False
+ for key in output_1:
+ if not compare_output(output_1[key], output_2[key]):
+ return False
+ return True
+ # Unsupported output type
+ else:
+ logger.error(
+            "Model returned invalid type for unit test, should be a th.Tensor "
+            "or a tuple, list, or dict of th.Tensor"
+ )
+ return False
+ return True
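+
+
+# Illustrative sketch (not exercised by the tests above): compare_output handles
+# tensors, tuples, lists, and dicts recursively, e.g.
+#
+#     x = th.ones(2, 2)
+#     assert compare_output(x, x.clone())                # identical tensors
+#     assert not compare_output(x, x + 1.0)              # differ beyond atol/rtol
+#     assert compare_output({"a": x}, {"a": x.clone()})  # dicts compared key-wise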
diff --git a/fme/fme/ace/registry/test_sfno.py b/fme/fme/ace/registry/test_sfno.py
index 046a039..95e7610 100644
--- a/fme/fme/ace/registry/test_sfno.py
+++ b/fme/fme/ace/registry/test_sfno.py
@@ -5,10 +5,11 @@
import pytest
import torch
-from fme.core.data_loading.data_typing import SigmaCoordinates
+from fme.ace.stepper import SingleModuleStepperConfig
+from fme.core.coordinates import HybridSigmaPressureCoordinate
from fme.core.device import get_device
-from fme.core.normalizer import FromStateNormalizer
-from fme.core.stepper import SingleModuleStepperConfig
+from fme.core.gridded_ops import LatLonOperations
+from fme.core.normalizer import NormalizationConfig
TIMESTEP = datetime.timedelta(hours=6)
@@ -34,23 +35,21 @@ def test_sfno_init(shape):
"in_names": ["x"],
"out_names": ["x"],
"normalization": dataclasses.asdict(
- FromStateNormalizer(
- state={
- "means": {"x": float(np.random.randn(1))},
- "stds": {"x": float(np.random.randn(1))},
- }
+ NormalizationConfig(
+ means={"x": float(np.random.randn(1).item())},
+ stds={"x": float(np.random.randn(1).item())},
)
),
}
area = torch.ones((1, 16, 32)).to(get_device())
- sigma_coordinates = SigmaCoordinates(ak=torch.arange(7), bk=torch.arange(7)).to(
- get_device()
- )
+ vertical_coordinate = HybridSigmaPressureCoordinate(
+ ak=torch.arange(7), bk=torch.arange(7)
+ ).to(get_device())
stepper_config = SingleModuleStepperConfig.from_state(stepper_config_data)
stepper = stepper_config.get_stepper(
img_shape=shape,
- area=area,
- sigma_coordinates=sigma_coordinates,
+ gridded_operations=LatLonOperations(area),
+ vertical_coordinate=vertical_coordinate,
timestep=TIMESTEP,
)
assert len(stepper.module.module.blocks) == num_layers
diff --git a/fme/fme/ace/requirements.py b/fme/fme/ace/requirements.py
new file mode 100644
index 0000000..4c31070
--- /dev/null
+++ b/fme/fme/ace/requirements.py
@@ -0,0 +1,16 @@
+import dataclasses
+from typing import List
+
+
+@dataclasses.dataclass
+class PrognosticStateDataRequirements:
+ """
+ The requirements for the model's prognostic state.
+
+ Parameters:
+ names: Names of prognostic variables.
+ n_timesteps: Number of consecutive timesteps that must be stored.
+ """
+
+ names: List[str]
+ n_timesteps: int
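+
+
+# Hypothetical usage sketch (variable names are illustrative only): a stepper with
+# a single initial-condition timestep for two prognostic variables could declare
+#
+#     PrognosticStateDataRequirements(
+#         names=["air_temperature", "specific_total_water_0"],
+#         n_timesteps=1,
+#     )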
diff --git a/fme/fme/ace/stepper.py b/fme/fme/ace/stepper.py
new file mode 100644
index 0000000..830ae14
--- /dev/null
+++ b/fme/fme/ace/stepper.py
@@ -0,0 +1,955 @@
+import dataclasses
+import datetime
+import logging
+from copy import copy
+from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union
+
+import dacite
+import torch
+import xarray as xr
+from torch import nn
+
+from fme.ace.data_loading.batch_data import BatchData, PairedData, PrognosticState
+from fme.ace.inference.derived_variables import compute_derived_quantities
+from fme.ace.requirements import PrognosticStateDataRequirements
+from fme.core.coordinates import HybridSigmaPressureCoordinate
+from fme.core.corrector.corrector import CorrectorConfig
+from fme.core.dataset.requirements import DataRequirements
+from fme.core.dataset.utils import decode_timestep, encode_timestep
+from fme.core.device import get_device
+from fme.core.distributed import Distributed
+from fme.core.generics.inference import PredictFunction
+from fme.core.generics.optimization import OptimizationABC
+from fme.core.generics.train_stepper import TrainOutputABC, TrainStepperABC
+from fme.core.gridded_ops import GriddedOperations, LatLonOperations
+from fme.core.loss import WeightedMappingLossConfig
+from fme.core.normalizer import NormalizationConfig, StandardNormalizer
+from fme.core.ocean import Ocean, OceanConfig
+from fme.core.optimization import NullOptimization
+from fme.core.packer import Packer
+from fme.core.parameter_init import ParameterInitializationConfig
+from fme.core.registry import CorrectorSelector, ModuleSelector
+from fme.core.timing import GlobalTimer
+from fme.core.typing_ import TensorDict, TensorMapping
+
+DEFAULT_TIMESTEP = datetime.timedelta(hours=6)
+DEFAULT_ENCODED_TIMESTEP = encode_timestep(DEFAULT_TIMESTEP)
+
+
+class AtmosphericDeriveFn:
+ def __init__(
+ self,
+ vertical_coordinate: HybridSigmaPressureCoordinate,
+ timestep: datetime.timedelta,
+ ):
+ self.vertical_coordinate = vertical_coordinate.to(
+ "cpu"
+ ) # must be on cpu for multiprocessing fork context
+ self.timestep = timestep
+
+ def __call__(self, data: TensorMapping, forcing_data: TensorMapping) -> TensorDict:
+ return compute_derived_quantities(
+ dict(data),
+ vertical_coordinate=self.vertical_coordinate.to(get_device()),
+ timestep=self.timestep,
+ forcing_data=dict(forcing_data),
+ )
+
+
+@dataclasses.dataclass
+class SingleModuleStepperConfig:
+ """
+ Configuration for a single module stepper.
+
+ Parameters:
+ builder: The module builder.
+ in_names: Names of input variables.
+ out_names: Names of output variables.
+ normalization: The normalization configuration.
+ parameter_init: The parameter initialization configuration.
+ ocean: The ocean configuration.
+ loss: The loss configuration.
+ corrector: The corrector configuration.
+ next_step_forcing_names: Names of forcing variables for the next timestep.
+ loss_normalization: The normalization configuration for the loss.
+ residual_normalization: Optional alternative to configure loss normalization.
+ If provided, it will be used for all *prognostic* variables in loss scaling.
+ """
+
+ builder: ModuleSelector
+ in_names: List[str]
+ out_names: List[str]
+ normalization: NormalizationConfig
+ parameter_init: ParameterInitializationConfig = dataclasses.field(
+ default_factory=lambda: ParameterInitializationConfig()
+ )
+ ocean: Optional[OceanConfig] = None
+ loss: WeightedMappingLossConfig = dataclasses.field(
+ default_factory=lambda: WeightedMappingLossConfig()
+ )
+ corrector: Union[CorrectorConfig, CorrectorSelector] = dataclasses.field(
+ default_factory=lambda: CorrectorConfig()
+ )
+ next_step_forcing_names: List[str] = dataclasses.field(default_factory=list)
+ loss_normalization: Optional[NormalizationConfig] = None
+ residual_normalization: Optional[NormalizationConfig] = None
+
+ def __post_init__(self):
+ for name in self.next_step_forcing_names:
+ if name not in self.in_names:
+ raise ValueError(
+ f"next_step_forcing_name '{name}' not in in_names: {self.in_names}"
+ )
+ if name in self.out_names:
+ raise ValueError(
+ f"next_step_forcing_name is an output variable: '{name}'"
+ )
+ if (
+ self.residual_normalization is not None
+ and self.loss_normalization is not None
+ ):
+ raise ValueError(
+ "Only one of residual_normalization, loss_normalization can "
+                "be provided. "
+                "If residual_normalization is provided, it will be used for all "
+                "*prognostic* variables in loss scaling. "
+ "If loss_normalization is provided, it will be used for all variables "
+ "in loss scaling."
+ )
+
+ @property
+ def n_ic_timesteps(self) -> int:
+ return 1
+
+ def get_evaluation_window_data_requirements(
+ self, n_forward_steps: int
+ ) -> DataRequirements:
+ return DataRequirements(
+ names=self.all_names,
+ n_timesteps=self._window_steps_required(n_forward_steps),
+ )
+
+ def get_prognostic_state_data_requirements(self) -> PrognosticStateDataRequirements:
+ return PrognosticStateDataRequirements(
+ names=self.prognostic_names,
+ n_timesteps=self.n_ic_timesteps,
+ )
+
+ def get_forcing_window_data_requirements(
+ self, n_forward_steps: int
+ ) -> DataRequirements:
+ if self.ocean is None:
+ names = self.forcing_names
+ else:
+ names = list(set(self.forcing_names).union(self.ocean.forcing_names))
+
+ return DataRequirements(
+ names=names,
+ n_timesteps=self._window_steps_required(n_forward_steps),
+ )
+
+ def _window_steps_required(self, n_forward_steps: int) -> int:
+ return n_forward_steps + self.n_ic_timesteps
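+    # For example, with n_ic_timesteps == 1, evaluating n_forward_steps == 4
+    # requires a window of 5 consecutive timesteps (1 initial condition + 4 targets).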
+
+ def get_state(self):
+ return dataclasses.asdict(self)
+
+ def get_base_weights(self) -> Optional[List[Mapping[str, Any]]]:
+ """
+ If the model is being initialized from another model's weights for fine-tuning,
+ returns those weights. Otherwise, returns None.
+
+ The list mirrors the order of `modules` in the `SingleModuleStepper` class.
+ """
+ base_weights = self.parameter_init.get_base_weights()
+ if base_weights is not None:
+ return [base_weights]
+ else:
+ return None
+
+ def get_stepper(
+ self,
+ img_shape: Tuple[int, int],
+ gridded_operations: GriddedOperations,
+ vertical_coordinate: HybridSigmaPressureCoordinate,
+ timestep: datetime.timedelta,
+ ):
+ logging.info("Initializing stepper from provided config")
+ derive_func = AtmosphericDeriveFn(vertical_coordinate, timestep)
+ return SingleModuleStepper(
+ config=self,
+ img_shape=img_shape,
+ gridded_operations=gridded_operations,
+ vertical_coordinate=vertical_coordinate,
+ timestep=timestep,
+ derive_func=derive_func,
+ )
+
+ @classmethod
+ def from_state(cls, state) -> "SingleModuleStepperConfig":
+ state = cls.remove_deprecated_keys(state)
+ return dacite.from_dict(
+ data_class=cls, data=state, config=dacite.Config(strict=True)
+ )
+
+ @property
+ def all_names(self):
+ """Names of all variables required, including auxiliary ones."""
+ extra_names = []
+ if self.ocean is not None:
+ extra_names.extend(self.ocean.forcing_names)
+ all_names = list(set(self.in_names).union(self.out_names).union(extra_names))
+ return all_names
+
+ @property
+ def normalize_names(self):
+        """Names of variables which require normalization, i.e. inputs and outputs."""
+ return list(set(self.in_names).union(self.out_names))
+
+ @property
+ def forcing_names(self) -> List[str]:
+ """Names of variables which are inputs only."""
+ return list(set(self.in_names) - set(self.out_names))
+
+ @property
+ def prognostic_names(self) -> List[str]:
+        """Names of variables which are both inputs and outputs."""
+ return list(set(self.out_names).intersection(self.in_names))
+
+ @property
+ def diagnostic_names(self) -> List[str]:
+        """Names of variables which are outputs only."""
+ return list(set(self.out_names).difference(self.in_names))
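+    # Illustrative example of the name properties: with in_names=["a", "b"] and
+    # out_names=["a", "c"], forcing_names == ["b"] (input only), prognostic_names
+    # == ["a"] (both input and output), and diagnostic_names == ["c"] (output only).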
+
+ @classmethod
+ def remove_deprecated_keys(cls, state: Dict[str, Any]) -> Dict[str, Any]:
+ _unsupported_key_defaults = {
+ "conserve_dry_air": False,
+ "optimization": None,
+ "conservation_loss": {"dry_air_penalty": None},
+ }
+ state_copy = state.copy()
+ for key, default in _unsupported_key_defaults.items():
+ if key in state_copy:
+ if state_copy[key] == default or state_copy[key] is None:
+ del state_copy[key]
+ else:
+ raise ValueError(
+ f"The stepper config option {key} is deprecated and the setting"
+ f" provided, {state_copy[key]}, is no longer implemented. The "
+ "SingleModuleStepper being loaded from state cannot be run by "
+ "this version of the code."
+ )
+ for normalization_key in [
+ "normalization",
+ "loss_normalization",
+ "residual_normalization",
+ ]:
+ if state_copy.get(normalization_key) is not None:
+ if "exclude_names" in state_copy[normalization_key]:
+ if state_copy[normalization_key]["exclude_names"] is not None:
+ raise ValueError(
+ "The exclude_names option in normalization config is no "
+ "longer supported, but excluded names were found in "
+ f"{normalization_key}."
+ )
+ else:
+ del state_copy[normalization_key]["exclude_names"]
+ if "prescriber" in state_copy:
+ # want to maintain backwards compatibility for this particular feature
+ if state_copy["prescriber"] is not None:
+ if state_copy.get("ocean") is not None:
+ raise ValueError("Cannot specify both prescriber and ocean.")
+ state_copy["ocean"] = {
+ "surface_temperature_name": state_copy["prescriber"][
+ "prescribed_name"
+ ],
+ "ocean_fraction_name": state_copy["prescriber"]["mask_name"],
+ "interpolate": state_copy["prescriber"]["interpolate"],
+ }
+ del state_copy["prescriber"]
+ return state_copy
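+    # For example (illustrative variable names), a legacy state containing
+    #     {"prescriber": {"prescribed_name": "TS", "mask_name": "ocean_fraction",
+    #                     "interpolate": False}}
+    # is rewritten by remove_deprecated_keys to
+    #     {"ocean": {"surface_temperature_name": "TS",
+    #                "ocean_fraction_name": "ocean_fraction",
+    #                "interpolate": False}}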
+
+
+@dataclasses.dataclass
+class ExistingStepperConfig:
+ """
+ Configuration for an existing stepper. This is only designed to point to
+ a serialized stepper checkpoint for loading, e.g., in the case of training
+ resumption.
+
+ Parameters:
+ checkpoint_path: The path to the serialized checkpoint.
+ """
+
+ checkpoint_path: str
+
+ def __post_init__(self):
+ self._stepper_config = SingleModuleStepperConfig.from_state(
+ self._load_checkpoint()["stepper"]["config"]
+ )
+
+ def _load_checkpoint(self) -> Mapping[str, Any]:
+ return torch.load(self.checkpoint_path, map_location=get_device())
+
+ def get_evaluation_window_data_requirements(
+ self, n_forward_steps: int
+ ) -> DataRequirements:
+ return self._stepper_config.get_evaluation_window_data_requirements(
+ n_forward_steps
+ )
+
+ def get_prognostic_state_data_requirements(self) -> PrognosticStateDataRequirements:
+ return self._stepper_config.get_prognostic_state_data_requirements()
+
+ def get_forcing_window_data_requirements(
+ self, n_forward_steps: int
+ ) -> DataRequirements:
+ return self._stepper_config.get_forcing_window_data_requirements(
+ n_forward_steps
+ )
+
+ def get_base_weights(self) -> Optional[List[Mapping[str, Any]]]:
+ return self._stepper_config.get_base_weights()
+
+ def get_stepper(self, img_shape, gridded_operations, vertical_coordinate, timestep):
+ del img_shape # unused
+ logging.info(f"Initializing stepper from {self.checkpoint_path}")
+ return SingleModuleStepper.from_state(self._load_checkpoint()["stepper"])
+
+
+def _combine_normalizers(
+ residual_normalizer: StandardNormalizer,
+ model_normalizer: StandardNormalizer,
+) -> StandardNormalizer:
+    # Combine residual and model normalizers by overwriting the model normalizer
+    # values with those present in the residual normalizer. The residual
+    # normalizer is assumed to contain a subset of the prognostic keys only.
+ means, stds = copy(model_normalizer.means), copy(model_normalizer.stds)
+ means.update(residual_normalizer.means)
+ stds.update(residual_normalizer.stds)
+ return StandardNormalizer(means=means, stds=stds)
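+# Sketch of the combination rule with illustrative values: if the model normalizer
+# has stds {"a": 1.0, "b": 1.0} and the residual normalizer (a prognostic subset)
+# has stds {"a": 0.1}, the combined normalizer uses stds {"a": 0.1, "b": 1.0};
+# means are merged the same way.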
+
+
+def _prepend_timesteps(
+ data: TensorMapping, timesteps: TensorMapping, time_dim: int = 1
+) -> TensorDict:
+ return {k: torch.cat([timesteps[k], v], dim=time_dim) for k, v in data.items()}
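+# Shape sketch: prepending an initial condition of shape [n_batch, 1, n_lat, n_lon]
+# to a window of shape [n_batch, n_forward_steps, n_lat, n_lon] along time_dim=1
+# yields [n_batch, n_forward_steps + 1, n_lat, n_lon] for each variable.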
+
+
+@dataclasses.dataclass
+class TrainOutput(TrainOutputABC):
+ metrics: TensorDict
+ gen_data: TensorDict
+ target_data: TensorDict
+ time: xr.DataArray
+ normalize: Callable[[TensorDict], TensorDict]
+ derive_func: Callable[[TensorMapping, TensorMapping], TensorDict] = (
+ lambda x, _: dict(x)
+ )
+
+ def remove_initial_condition(self, n_ic_timesteps: int) -> "TrainOutput":
+ return TrainOutput(
+ metrics=self.metrics,
+ gen_data={k: v[:, n_ic_timesteps:] for k, v in self.gen_data.items()},
+ target_data={k: v[:, n_ic_timesteps:] for k, v in self.target_data.items()},
+ time=self.time[:, n_ic_timesteps:],
+ normalize=self.normalize,
+ derive_func=self.derive_func,
+ )
+
+ def copy(self) -> "TrainOutput":
+ """Creates new dictionaries for the data but with the same tensors."""
+ return TrainOutput(
+ metrics=self.metrics,
+ gen_data={k: v for k, v in self.gen_data.items()},
+ target_data={k: v for k, v in self.target_data.items()},
+ time=self.time,
+ normalize=self.normalize,
+ derive_func=self.derive_func,
+ )
+
+ def prepend_initial_condition(
+ self,
+ initial_condition: PrognosticState,
+ ) -> "TrainOutput":
+ """
+ Prepends an initial condition to the existing stepped data.
+ Assumes data are on the same device.
+ For data windows > 0, the target IC is different from the generated IC
+ and may be provided for correct calculation of tendencies.
+
+ Args:
+ initial_condition: Initial condition data.
+ """
+ batch_data = initial_condition.as_batch_data()
+ return TrainOutput(
+ metrics=self.metrics,
+ gen_data=_prepend_timesteps(self.gen_data, batch_data.data),
+ target_data=_prepend_timesteps(
+ self.target_data,
+ batch_data.data,
+ ),
+ time=xr.concat([batch_data.time, self.time], dim="time"),
+ normalize=self.normalize,
+ derive_func=self.derive_func,
+ )
+
+ def compute_derived_variables(
+ self,
+ ) -> "TrainOutput":
+ gen_data = self.derive_func(self.gen_data, self.target_data)
+ target_data = self.derive_func(self.target_data, self.target_data)
+ return TrainOutput(
+ metrics=self.metrics,
+ gen_data=gen_data,
+ target_data=target_data,
+ time=self.time,
+ normalize=self.normalize,
+ derive_func=self.derive_func,
+ )
+
+ def get_metrics(self) -> TensorDict:
+ return self.metrics
+
+
+class SingleModuleStepper(
+ TrainStepperABC[
+ PrognosticState,
+ BatchData,
+ BatchData,
+ PairedData,
+ TrainOutput,
+ ],
+):
+ """
+ Stepper class for a single pytorch module.
+ """
+
+ TIME_DIM = 1
+ CHANNEL_DIM = -3
+
+ def __init__(
+ self,
+ config: SingleModuleStepperConfig,
+ img_shape: Tuple[int, int],
+ gridded_operations: GriddedOperations,
+ vertical_coordinate: HybridSigmaPressureCoordinate,
+ derive_func: Callable[[TensorMapping, TensorMapping], TensorDict],
+ timestep: datetime.timedelta,
+ init_weights: bool = True,
+ ):
+ """
+ Args:
+ config: The configuration.
+ img_shape: Shape of domain as (n_lat, n_lon).
+ gridded_operations: The gridded operations, e.g. for area weighting.
+ vertical_coordinate: The vertical coordinate.
+ derive_func: Function to compute derived variables.
+ timestep: Timestep of the model.
+ init_weights: Whether to initialize the weights. Should pass False if
+ the weights are about to be overwritten by a checkpoint.
+ """
+ self._gridded_operations = gridded_operations # stored for serializing
+ n_in_channels = len(config.in_names)
+ n_out_channels = len(config.out_names)
+ self.in_packer = Packer(config.in_names)
+ self.out_packer = Packer(config.out_names)
+ self.normalizer = config.normalization.build(config.normalize_names)
+ if config.ocean is not None:
+ self.ocean: Optional[Ocean] = config.ocean.build(
+ config.in_names, config.out_names, timestep
+ )
+ else:
+ self.ocean = None
+ self.module = config.builder.build(
+ n_in_channels=n_in_channels,
+ n_out_channels=n_out_channels,
+ img_shape=img_shape,
+ )
+ module, self._l2_sp_tuning_regularizer = config.parameter_init.apply(
+ self.module, init_weights=init_weights
+ )
+ self.module = module.to(get_device())
+ self.derive_func = derive_func
+ self._img_shape = img_shape
+ self._config = config
+ self._no_optimization = NullOptimization()
+
+ dist = Distributed.get_instance()
+ self._is_distributed = dist.is_distributed()
+ self.module = dist.wrap_module(self.module)
+
+ self._vertical_coordinates = vertical_coordinate.to(get_device())
+ self._timestep = timestep
+
+ self.loss_obj = config.loss.build(
+ gridded_operations.area_weighted_mean, config.out_names, self.CHANNEL_DIM
+ )
+
+ self._corrector = config.corrector.build(
+ gridded_operations=gridded_operations,
+ vertical_coordinate=self.vertical_coordinate,
+ timestep=timestep,
+ )
+ if config.loss_normalization is not None:
+ self.loss_normalizer = config.loss_normalization.build(
+ names=config.normalize_names
+ )
+ elif config.residual_normalization is not None:
+ # Use residual norm for prognostic variables and input/output
+ # normalizer for diagnostic variables in loss
+ self.loss_normalizer = _combine_normalizers(
+ residual_normalizer=config.residual_normalization.build(
+ config.prognostic_names
+ ),
+ model_normalizer=self.normalizer,
+ )
+ else:
+ self.loss_normalizer = self.normalizer
+ self.in_names = config.in_names
+ self.out_names = config.out_names
+
+ _1: PredictFunction[ # for type checking
+ PrognosticState,
+ BatchData,
+ BatchData,
+ ] = self.predict
+
+ _2: PredictFunction[ # for type checking
+ PrognosticState,
+ BatchData,
+ PairedData,
+ ] = self.predict_paired
+
+ @property
+ def vertical_coordinate(self) -> HybridSigmaPressureCoordinate:
+ return self._vertical_coordinates
+
+ @property
+ def timestep(self) -> datetime.timedelta:
+ return self._timestep
+
+ @property
+ def surface_temperature_name(self) -> Optional[str]:
+ if self._config.ocean is not None:
+ return self._config.ocean.surface_temperature_name
+ return None
+
+ @property
+ def ocean_fraction_name(self) -> Optional[str]:
+ if self._config.ocean is not None:
+ return self._config.ocean.ocean_fraction_name
+ return None
+
+ @property
+ def effective_loss_scaling(self) -> TensorDict:
+ """
+ Effective loss scalings used to normalize outputs before computing loss.
+ y_loss_normalized_i = (y_i - y_mean_i) / loss_scaling_i
+ where loss_scaling_i = loss_normalizer_std_i / weight_i.
+ """
+ custom_weights = self._config.loss.weights
+ loss_normalizer_stds = self.loss_normalizer.stds
+ return {
+ k: loss_normalizer_stds[k] / custom_weights.get(k, 1.0)
+ for k in self._config.out_names
+ }
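+    # Worked example with illustrative numbers: a loss-normalizer std of 2.0 and a
+    # custom weight of 4.0 give an effective scaling of 2.0 / 4.0 = 0.5, so that
+    # variable's errors are divided by a smaller value and weigh more in the loss.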
+
+ def replace_ocean(self, ocean: Ocean):
+ """
+ Replace the ocean model with a new one.
+
+ Args:
+ ocean: The new ocean model.
+ """
+ self.ocean = ocean
+
+ @property
+ def forcing_names(self) -> List[str]:
+ """Names of variables which are inputs only."""
+ return self._config.forcing_names
+
+ @property
+ def prognostic_names(self) -> List[str]:
+ return sorted(
+ list(set(self.out_packer.names).intersection(self.in_packer.names))
+ )
+
+ @property
+ def diagnostic_names(self) -> List[str]:
+ return sorted(list(set(self.out_packer.names).difference(self.in_packer.names)))
+
+ @property
+ def n_ic_timesteps(self) -> int:
+ return 1
+
+ @property
+ def modules(self) -> nn.ModuleList:
+ """
+ Returns:
+ A list of modules being trained.
+ """
+ return nn.ModuleList([self.module])
+
+ def step(
+ self,
+ input: TensorMapping,
+ next_step_forcing_data: TensorMapping,
+ ) -> TensorDict:
+ """
+ Step the model forward one timestep given input data.
+
+ Args:
+ input: Mapping from variable name to tensor of shape
+ [n_batch, n_lat, n_lon]. This data is used as input for `self.module`
+ and is assumed to contain all input variables and be denormalized.
+ next_step_forcing_data: Mapping from variable name to tensor of shape
+ [n_batch, n_lat, n_lon]. This must contain the necessary forcing
+ data at the output timestep for the ocean model and corrector.
+
+ Returns:
+ The denormalized output data at the next time step.
+ """
+ input_norm = self.normalizer.normalize(input)
+ input_tensor = self.in_packer.pack(input_norm, axis=self.CHANNEL_DIM)
+ output_tensor = self.module(input_tensor)
+ output_norm = self.out_packer.unpack(output_tensor, axis=self.CHANNEL_DIM)
+ output = self.normalizer.denormalize(output_norm)
+ if self._corrector is not None:
+ output = self._corrector(input, output, next_step_forcing_data)
+ if self.ocean is not None:
+ output = self.ocean(input, output, next_step_forcing_data)
+ return output
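+    # Shape sketch for a single step (assuming lat-lon inputs): each input variable
+    # is [n_batch, n_lat, n_lon]; packing along CHANNEL_DIM stacks them into
+    # [n_batch, n_in_channels, n_lat, n_lon] for the module, and the output is
+    # unpacked back to per-variable [n_batch, n_lat, n_lon] tensors before
+    # denormalization and the optional corrector/ocean overrides.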
+
+ def _predict(
+ self,
+ initial_condition: TensorMapping,
+ forcing_data: TensorMapping,
+ n_forward_steps: int,
+ ) -> TensorDict:
+ """
+ Predict multiple steps forward given initial condition and forcing data.
+
+ Uses low-level inputs and does not compute derived variables, to separate
+ concerns from the public `predict` method.
+
+ Args:
+            initial_condition: The initial condition, containing tensors of shape
+                [n_batch, self.n_ic_timesteps, n_lat, n_lon].
+            forcing_data: The forcing data, containing tensors of shape
+                [n_batch, n_forward_steps + self.n_ic_timesteps, n_lat, n_lon].
+ n_forward_steps: The number of forward steps to predict, corresponding
+ to the data shapes of forcing_data.
+
+ Returns:
+ The output data at each timestep.
+ """
+ state = {
+ k: initial_condition[k].squeeze(self.TIME_DIM) for k in initial_condition
+ }
+ ml_forcing_names = self._config.forcing_names
+ output_list = []
+ for step in range(n_forward_steps):
+ ml_input_forcing = {
+ k: (
+ forcing_data[k][:, step]
+ if k not in self._config.next_step_forcing_names
+ else forcing_data[k][:, step + 1]
+ )
+ for k in ml_forcing_names
+ }
+ next_step_forcing_data = {
+ k: forcing_data[k][:, step + 1] for k in self._forcing_names()
+ }
+ input_data = {**state, **ml_input_forcing}
+ state = self.step(input_data, next_step_forcing_data)
+ output_list.append(state)
+ output_timeseries = {}
+ for name in state:
+ output_timeseries[name] = torch.stack(
+ [x[name] for x in output_list], dim=self.TIME_DIM
+ )
+ return output_timeseries
+
+ def predict(
+ self,
+ initial_condition: PrognosticState,
+ forcing: BatchData,
+ compute_derived_variables: bool = False,
+ ) -> Tuple[BatchData, PrognosticState]:
+ """
+        Predict multiple steps forward given an initial condition and forcing data.
+
+ Args:
+ initial_condition: Prognostic state data with tensors of shape
+                [n_batch, self.n_ic_timesteps, n_lat, n_lon]. This data is assumed
+ to contain all prognostic variables and be denormalized.
+ forcing: Contains tensors of shape
+ [n_batch, self.n_ic_timesteps + n_forward_steps, n_lat, n_lon]. This
+ contains the forcing and ocean data for the initial condition and all
+ subsequent timesteps.
+ compute_derived_variables: Whether to compute derived variables for the
+ prediction.
+
+ Returns:
+            The prediction as batch data, and the prediction's final state,
+            which can be used as a new initial condition.
+ """
+ timer = GlobalTimer.get_instance()
+ with timer.context("forward_prediction"):
+ forcing_data = forcing.subset_names(self._forcing_names())
+ initial_condition_state = initial_condition.as_batch_data()
+ if initial_condition_state.time.shape[1] != self.n_ic_timesteps:
+ raise ValueError(
+ f"Initial condition must have {self.n_ic_timesteps} timesteps, got "
+ f"{initial_condition_state.time.shape[1]}."
+ )
+ n_forward_steps = forcing_data.time.shape[1] - self.n_ic_timesteps
+ output_timeseries = self._predict(
+ initial_condition_state.data, forcing_data.data, n_forward_steps
+ )
+ data = BatchData.new_on_device(
+ output_timeseries,
+ forcing_data.time[:, self.n_ic_timesteps :],
+ horizontal_dims=forcing_data.horizontal_dims,
+ )
+ if compute_derived_variables:
+ with timer.context("compute_derived_variables"):
+ data = (
+ data.prepend(initial_condition)
+ .compute_derived_variables(
+ derive_func=self.derive_func,
+ forcing_data=forcing_data,
+ )
+ .remove_initial_condition(self.n_ic_timesteps)
+ )
+ return data, data.get_end(self.prognostic_names, self.n_ic_timesteps)
+
+ def predict_paired(
+ self,
+ initial_condition: PrognosticState,
+ forcing: BatchData,
+ compute_derived_variables: bool = False,
+ ) -> Tuple[PairedData, PrognosticState]:
+ """
+        Predict multiple steps forward given an initial condition and forcing data.
+
+ Args:
+ initial_condition: Prognostic state data with tensors of shape
+                [n_batch, self.n_ic_timesteps, n_lat, n_lon]. This data is assumed
+ to contain all prognostic variables and be denormalized.
+ forcing: Contains tensors of shape
+ [n_batch, self.n_ic_timesteps + n_forward_steps, n_lat, n_lon]. This
+ contains the forcing and ocean data for the initial condition and all
+ subsequent timesteps.
+ compute_derived_variables: Whether to compute derived variables for the
+ prediction.
+
+ Returns:
+            A paired data object containing the prediction and the target/forcing
+            data at the same timesteps, plus the prediction's final state, which
+            can be used as a new initial condition.
+ """
+ prediction, new_initial_condition = self.predict(
+ initial_condition, forcing, compute_derived_variables
+ )
+ return (
+ PairedData.from_batch_data(
+ prediction=prediction,
+ target=self.get_forward_data(
+ forcing, compute_derived_variables=compute_derived_variables
+ ),
+ ),
+ new_initial_condition,
+ )
+
+ def get_forward_data(
+ self, data: BatchData, compute_derived_variables: bool = False
+ ) -> BatchData:
+ if compute_derived_variables:
+ timer = GlobalTimer.get_instance()
+ with timer.context("compute_derived_variables"):
+ data = data.compute_derived_variables(
+ derive_func=self.derive_func,
+ forcing_data=data,
+ )
+ return data.remove_initial_condition(self.n_ic_timesteps)
+
+ def _forcing_names(self) -> List[str]:
+ if self.ocean is None:
+ return self._config.forcing_names
+ return list(set(self._config.forcing_names).union(self.ocean.forcing_names))
+
+ def train_on_batch(
+ self,
+ data: BatchData,
+ optimization: OptimizationABC,
+ compute_derived_variables: bool = False,
+ ) -> TrainOutput:
+ """
+ Step the model forward multiple steps on a batch of data.
+
+ Args:
+ data: The batch data where each tensor in data.data has shape
+                [n_sample, n_forward_steps + self.n_ic_timesteps, n_lat, n_lon].
+ optimization: The optimization class to use for updating the module.
+ Use `NullOptimization` to disable training.
+ compute_derived_variables: Whether to compute derived variables for the
+ prediction and target data.
+
+ Returns:
+            A TrainOutput containing the loss metrics, the generated data, the
+            target data, and a function for normalizing them.
+ """
+ time_dim = self.TIME_DIM
+
+ loss = torch.tensor(0.0, device=get_device())
+ metrics: Dict[str, float] = {}
+ input_data = data.get_start(self.prognostic_names, self.n_ic_timesteps)
+
+ optimization.set_mode(self.module)
+ with optimization.autocast():
+ # output from self.predict does not include initial condition
+ output, _ = self.predict_paired(
+ input_data,
+ forcing=data,
+ )
+ gen_data = output.prediction
+ target_data = output.target
+ n_forward_steps = output.time.shape[1]
+
+ # compute loss for each timestep
+ for step in range(n_forward_steps):
+ # Note: here we examine the loss for a single timestep,
+ # not a single model call (which may contain multiple timesteps).
+ gen_step = {k: v.select(time_dim, step) for k, v in gen_data.items()}
+ target_step = {
+ k: v.select(time_dim, step) for k, v in target_data.items()
+ }
+ gen_norm_step = self.loss_normalizer.normalize(gen_step)
+ target_norm_step = self.loss_normalizer.normalize(target_step)
+
+ step_loss = self.loss_obj(gen_norm_step, target_norm_step)
+ loss += step_loss
+ metrics[f"loss_step_{step}"] = step_loss.detach()
+
+ loss += self._l2_sp_tuning_regularizer()
+
+ metrics["loss"] = loss.detach()
+ optimization.step_weights(loss)
+
+ stepped = TrainOutput(
+ metrics=metrics,
+ gen_data=dict(gen_data),
+ target_data=dict(target_data),
+ time=output.time,
+ normalize=self.normalizer.normalize,
+ derive_func=self.derive_func,
+ )
+        ic = data.get_start(
+            set(data.data.keys()), self.n_ic_timesteps
+        )  # the full data window, not just prognostic variables, is prepended
+ stepped = stepped.prepend_initial_condition(ic)
+ if compute_derived_variables:
+ stepped = stepped.compute_derived_variables()
+ return stepped
+
+ def get_state(self):
+ """
+ Returns:
+ The state of the stepper.
+ """
+ return {
+ "module": self.module.state_dict(),
+ "normalizer": self.normalizer.get_state(),
+ "img_shape": self._img_shape,
+ "config": self._config.get_state(),
+ "gridded_operations": self._gridded_operations.to_state(),
+ "vertical_coordinate": self.vertical_coordinate.as_dict(),
+ "encoded_timestep": encode_timestep(self.timestep),
+ "loss_normalizer": self.loss_normalizer.get_state(),
+ }
+
+ def load_state(self, state: Dict[str, Any]) -> None:
+ """
+ Load the state of the stepper.
+
+ Args:
+ state: The state to load.
+ """
+ if "module" in state:
+ module = state["module"]
+ if "module.device_buffer" in module:
+ # for backwards compatibility with old checkpoints
+ del module["module.device_buffer"]
+ self.module.load_state_dict(module)
+
+ @classmethod
+ def from_state(cls, state) -> "SingleModuleStepper":
+ """
+ Load the state of the stepper.
+
+ Args:
+ state: The state to load.
+
+ Returns:
+ The stepper.
+ """
+ config = {**state["config"]} # make a copy to avoid mutating input
+ config["normalization"] = state["normalizer"]
+
+ # for backwards compatibility with previous steppers created w/o
+ # loss_normalization or residual_normalization
+ loss_normalizer_state = state.get("loss_normalizer", state["normalizer"])
+ config["loss_normalization"] = loss_normalizer_state
+
+ # Overwrite the residual_normalization key if it exists, since the combined
+ # loss scalings are saved in initial training as the loss_normalization
+ config["residual_normalization"] = None
+
+ if "area" in state:
+ # backwards-compatibility, these older checkpoints are always lat-lon
+ gridded_operations: GriddedOperations = LatLonOperations(state["area"])
+ else:
+ gridded_operations = GriddedOperations.from_state(
+ state["gridded_operations"]
+ )
+
+ if "sigma_coordinates" in state:
+ # for backwards compatibility with old checkpoints
+ state["vertical_coordinate"] = state["sigma_coordinates"]
+
+ vertical_coordinate = dacite.from_dict(
+ data_class=HybridSigmaPressureCoordinate,
+ data=state["vertical_coordinate"],
+ config=dacite.Config(strict=True),
+ )
+ # for backwards compatibility with original ACE checkpoint which
+ # serialized vertical coordinates as float64
+ if vertical_coordinate.ak.dtype == torch.float64:
+ vertical_coordinate.ak = vertical_coordinate.ak.to(dtype=torch.float32)
+ if vertical_coordinate.bk.dtype == torch.float64:
+ vertical_coordinate.bk = vertical_coordinate.bk.to(dtype=torch.float32)
+ encoded_timestep = state.get("encoded_timestep", DEFAULT_ENCODED_TIMESTEP)
+ timestep = decode_timestep(encoded_timestep)
+ if "img_shape" in state:
+ img_shape = state["img_shape"]
+ else:
+ # this is for backwards compatibility with old checkpoints
+ for v in state["data_shapes"].values():
+ img_shape = v[-2:]
+ break
+ derive_func = AtmosphericDeriveFn(vertical_coordinate, timestep)
+ stepper = cls(
+ config=SingleModuleStepperConfig.from_state(config),
+ img_shape=img_shape,
+ gridded_operations=gridded_operations,
+ vertical_coordinate=vertical_coordinate,
+ timestep=timestep,
+ derive_func=derive_func,
+ # don't need to initialize weights, we're about to load_state
+ init_weights=False,
+ )
+ stepper.load_state(state)
+ return stepper
diff --git a/fme/fme/core/test_stepper.py b/fme/fme/ace/test_stepper.py
similarity index 56%
rename from fme/fme/core/test_stepper.py
rename to fme/fme/ace/test_stepper.py
index b45a5ed..2ed9f6f 100644
--- a/fme/fme/core/test_stepper.py
+++ b/fme/fme/ace/test_stepper.py
@@ -3,46 +3,57 @@
from typing import Iterable, List, Literal, Optional, Tuple, Union
from unittest.mock import MagicMock
+import cftime
import numpy as np
import pytest
import torch
+import xarray as xr
import fme
-from fme.ace.inference.derived_variables import compute_stepped_derived_quantities
+from fme.ace.aggregator import OneStepAggregator
+from fme.ace.aggregator.plotting import plot_paneled_data
+from fme.ace.data_loading.batch_data import BatchData, PrognosticState
+from fme.ace.stepper import (
+ CorrectorConfig,
+ SingleModuleStepper,
+ SingleModuleStepperConfig,
+ TrainOutput,
+ _combine_normalizers,
+)
from fme.core import ClimateData, metrics
-from fme.core.data_loading.data_typing import SigmaCoordinates
+from fme.core.coordinates import HybridSigmaPressureCoordinate
from fme.core.device import get_device
+from fme.core.gridded_ops import LatLonOperations
from fme.core.loss import WeightedMappingLossConfig
from fme.core.normalizer import NormalizationConfig, StandardNormalizer
from fme.core.ocean import OceanConfig, SlabOceanConfig
from fme.core.optimization import NullOptimization, Optimization, OptimizationConfig
-from fme.core.registry import ModuleSelector
-from fme.core.stepper import (
- CorrectorConfig,
- SingleModuleStepper,
- SingleModuleStepperConfig,
- SteppedData,
- _combine_normalizers,
-)
+from fme.core.registry.module import ModuleSelector
from fme.core.typing_ import TensorDict
-SphericalData = namedtuple("SphericalData", ["data", "area_weights", "sigma_coords"])
+SphericalData = namedtuple("SphericalData", ["data", "area_weights", "vertical_coord"])
TIMESTEP = datetime.timedelta(hours=6)
+DEVICE = fme.get_device()
def get_data(names: Iterable[str], n_samples, n_time) -> SphericalData:
- data = {}
+ data_dict = {}
n_lat, n_lon, nz = 5, 5, 7
     lats = torch.linspace(-89.5, 89.5, n_lat)  # arbitrary choice
for name in names:
- data[name] = torch.rand(
- n_samples, n_time, n_lat, n_lon, device=fme.get_device()
- )
- area_weights = fme.spherical_area_weights(lats, n_lon).to(fme.get_device())
+ data_dict[name] = torch.rand(n_samples, n_time, n_lat, n_lon, device=DEVICE)
+ area_weights = fme.spherical_area_weights(lats, n_lon).to(DEVICE)
ak, bk = torch.arange(nz), torch.arange(nz)
- sigma_coords = SigmaCoordinates(ak, bk)
- return SphericalData(data, area_weights, sigma_coords)
+ vertical_coord = HybridSigmaPressureCoordinate(ak, bk)
+ data = BatchData.new_on_device(
+ data=data_dict,
+ time=xr.DataArray(
+ np.zeros((n_samples, n_time)),
+ dims=["sample", "time"],
+ ),
+ )
+ return SphericalData(data, area_weights, vertical_coord)
def get_scalar_data(names, value):
@@ -92,42 +103,53 @@ def test_stepper_config_all_names_property(
assert set(config.all_names) == set(expected_all_names)
-def test_run_on_batch_normalizer_changes_only_norm_data():
+def test_train_on_batch_normalizer_changes_only_norm_data():
torch.manual_seed(0)
data = get_data(["a", "b"], n_samples=5, n_time=2).data
- area = torch.ones((5, 5), device=fme.get_device())
- sigma_coordinates = SigmaCoordinates(ak=torch.arange(7), bk=torch.arange(7))
+ area = torch.ones((5, 5), device=DEVICE)
+ gridded_operations = LatLonOperations(area)
+ vertical_coordinate = HybridSigmaPressureCoordinate(
+ ak=torch.arange(7), bk=torch.arange(7)
+ )
+ normalization_config = NormalizationConfig(
+ means=get_scalar_data(["a", "b"], 0.0),
+ stds=get_scalar_data(["a", "b"], 1.0),
+ )
config = SingleModuleStepperConfig(
builder=ModuleSelector(type="prebuilt", config={"module": torch.nn.Identity()}),
in_names=["a", "b"],
out_names=["a", "b"],
- normalization=NormalizationConfig(
- means=get_scalar_data(["a", "b"], 0.0),
- stds=get_scalar_data(["a", "b"], 1.0),
- ),
+ normalization=normalization_config,
loss=WeightedMappingLossConfig(type="MSE"),
)
- stepper = config.get_stepper((5, 5), area, sigma_coordinates, TIMESTEP)
- stepped = stepper.run_on_batch(data=data, optimization=MagicMock())
+ stepper = config.get_stepper(
+ (5, 5), gridded_operations, vertical_coordinate, TIMESTEP
+ )
+ stepped = stepper.train_on_batch(data=data, optimization=MagicMock())
assert torch.allclose(
- stepped.gen_data["a"], stepped.gen_data_norm["a"]
+ stepped.gen_data["a"], stepped.normalize(stepped.gen_data)["a"]
) # as std=1, mean=0, no change
- config.normalization.stds = get_scalar_data(["a", "b"], 2.0)
+ normalization_config.stds = get_scalar_data(["a", "b"], 2.0)
+ config.normalization = normalization_config
config.loss_normalization = NormalizationConfig(
means=get_scalar_data(["a", "b"], 0.0),
stds=get_scalar_data(["a", "b"], 3.0),
)
- stepper = config.get_stepper((5, 5), area, sigma_coordinates, TIMESTEP)
- stepped_double_std = stepper.run_on_batch(data=data, optimization=MagicMock())
+ stepper = config.get_stepper(
+ (5, 5), gridded_operations, vertical_coordinate, TIMESTEP
+ )
+ stepped_double_std = stepper.train_on_batch(data=data, optimization=MagicMock())
assert torch.allclose(
stepped.gen_data["a"], stepped_double_std.gen_data["a"], rtol=1e-4
)
assert torch.allclose(
- stepped.gen_data["a"], 2.0 * stepped_double_std.gen_data_norm["a"], rtol=1e-4
+ stepped.gen_data["a"],
+ 2.0 * stepped_double_std.normalize(stepped_double_std.gen_data)["a"],
+ rtol=1e-4,
)
assert torch.allclose(
stepped.target_data["a"],
- 2.0 * stepped_double_std.target_data_norm["a"],
+ 2.0 * stepped_double_std.normalize(stepped_double_std.target_data)["a"],
rtol=1e-4,
)
assert torch.allclose(
@@ -135,7 +157,7 @@ def test_run_on_batch_normalizer_changes_only_norm_data():
) # mse scales with std**2
-def test_run_on_batch_addition_series():
+def test_train_on_batch_addition_series():
torch.manual_seed(0)
class AddOne(torch.nn.Module):
@@ -143,9 +165,12 @@ def forward(self, x):
return x + 1
n_steps = 4
- data_with_ic = get_data(["a", "b"], n_samples=5, n_time=n_steps + 1).data
- area = torch.ones((5, 5), device=fme.get_device())
- sigma_coordinates = SigmaCoordinates(ak=torch.arange(7), bk=torch.arange(7))
+ data_with_ic: BatchData = get_data(["a", "b"], n_samples=5, n_time=n_steps + 1).data
+ area = torch.ones((5, 5), device=DEVICE)
+ gridded_operations = LatLonOperations(area)
+ vertical_coordinate = HybridSigmaPressureCoordinate(
+ ak=torch.arange(7), bk=torch.arange(7)
+ )
config = SingleModuleStepperConfig(
builder=ModuleSelector(type="prebuilt", config={"module": AddOne()}),
in_names=["a", "b"],
@@ -156,20 +181,21 @@ def forward(self, x):
),
loss=WeightedMappingLossConfig(type="MSE"),
)
- stepper = config.get_stepper((5, 5), area, sigma_coordinates, TIMESTEP)
- stepped = stepper.run_on_batch(
- data=data_with_ic, optimization=MagicMock(), n_forward_steps=n_steps
+ stepper = config.get_stepper(
+ (5, 5), gridded_operations, vertical_coordinate, TIMESTEP
)
- # output of run_on_batch does not include the initial condition
- assert stepped.gen_data["a"].shape == (5, n_steps, 5, 5)
- data = {k: data_with_ic[k][:, 1:] for k in data_with_ic}
+ stepped = stepper.train_on_batch(data=data_with_ic, optimization=MagicMock())
+    # output of train_on_batch includes the initial condition
+ assert stepped.gen_data["a"].shape == (5, n_steps + 1, 5, 5)
for i in range(n_steps - 1):
assert torch.allclose(
- stepped.gen_data_norm["a"][:, i] + 1, stepped.gen_data_norm["a"][:, i + 1]
+ stepped.normalize(stepped.gen_data)["a"][:, i] + 1,
+ stepped.normalize(stepped.gen_data)["a"][:, i + 1],
)
assert torch.allclose(
- stepped.gen_data_norm["b"][:, i] + 1, stepped.gen_data_norm["b"][:, i + 1]
+ stepped.normalize(stepped.gen_data)["b"][:, i] + 1,
+ stepped.normalize(stepped.gen_data)["b"][:, i + 1],
)
assert torch.allclose(
stepped.gen_data["a"][:, i] + 1, stepped.gen_data["a"][:, i + 1]
@@ -177,11 +203,17 @@ def forward(self, x):
assert torch.allclose(
stepped.gen_data["b"][:, i] + 1, stepped.gen_data["b"][:, i + 1]
)
- assert torch.allclose(stepped.target_data_norm["a"], data["a"])
- assert torch.allclose(stepped.target_data_norm["b"], data["b"])
+ assert torch.allclose(
+ stepped.normalize(stepped.target_data)["a"],
+ data_with_ic.data["a"],
+ )
+ assert torch.allclose(
+ stepped.normalize(stepped.target_data)["b"],
+ data_with_ic.data["b"],
+ )
-def test_run_on_batch_with_prescribed_ocean():
+def test_train_on_batch_with_prescribed_ocean():
torch.manual_seed(0)
class AddOne(torch.nn.Module):
@@ -189,45 +221,48 @@ def forward(self, x):
return x + 1
n_steps = 3
- data = get_data(["a", "b", "mask"], n_samples=5, n_time=n_steps + 1).data
- data["mask"] = torch.zeros_like(data["mask"], dtype=torch.int)
- data["mask"][:, :, :, 0] = 1
+ data: BatchData = get_data(["a", "b", "mask"], n_samples=5, n_time=n_steps + 1).data
+ data.data["mask"][:] = 0
+ data.data["mask"][:, :, :, 0] = 1
stds = {
"a": np.array([2.0], dtype=np.float32),
"b": np.array([3.0], dtype=np.float32),
- "mask": np.array([1.0], dtype=np.float32),
}
- area = torch.ones((5, 5), device=fme.get_device())
- sigma_coordinates = SigmaCoordinates(ak=torch.arange(7), bk=torch.arange(7))
+ area = torch.ones((5, 5), device=DEVICE)
+ gridded_operations = LatLonOperations(area)
+ vertical_coordinate = HybridSigmaPressureCoordinate(
+ ak=torch.arange(7), bk=torch.arange(7)
+ )
config = SingleModuleStepperConfig(
builder=ModuleSelector(type="prebuilt", config={"module": AddOne()}),
in_names=["a", "b"],
out_names=["a", "b"],
normalization=NormalizationConfig(
- means=get_scalar_data(["a", "b", "mask"], 0.0),
+ means=get_scalar_data(["a", "b"], 0.0),
stds=stds,
),
ocean=OceanConfig("b", "mask"),
)
- stepper = config.get_stepper(area.shape, area, sigma_coordinates, TIMESTEP)
- stepped = stepper.run_on_batch(
- data, optimization=MagicMock(), n_forward_steps=n_steps
+ stepper = config.get_stepper(
+ area.shape, gridded_operations, vertical_coordinate, TIMESTEP
)
+ stepped = stepper.train_on_batch(data, optimization=MagicMock())
for i in range(n_steps - 1):
# "a" should be increasing by 1 according to AddOne
torch.testing.assert_close(
- stepped.gen_data_norm["a"][:, i] + 1, stepped.gen_data_norm["a"][:, i + 1]
+ stepped.normalize(stepped.gen_data)["a"][:, i] + 1,
+ stepped.normalize(stepped.gen_data)["a"][:, i + 1],
)
# "b" should be increasing by 1 where the mask says don't prescribe
# note the 1: selection for the last dimension in following two assertions
torch.testing.assert_close(
- stepped.gen_data_norm["b"][:, i, :, 1:] + 1,
- stepped.gen_data_norm["b"][:, i + 1, :, 1:],
+ stepped.normalize(stepped.gen_data)["b"][:, i, :, 1:] + 1,
+ stepped.normalize(stepped.gen_data)["b"][:, i + 1, :, 1:],
)
# now check that the 0th index in last dimension has been overwritten
torch.testing.assert_close(
- stepped.gen_data_norm["b"][:, i, :, 0],
- stepped.target_data_norm["b"][:, i, :, 0],
+ stepped.normalize(stepped.gen_data)["b"][:, i, :, 0],
+ stepped.normalize({"b": stepped.target_data["b"]})["b"][:, i, :, 0],
)
@@ -245,47 +280,33 @@ def test_reloaded_stepper_gives_same_prediction():
),
)
shapes = {
- "a": (1, 1, 5, 5),
- "b": (1, 1, 5, 5),
+ "a": (1, 2, 5, 5),
+ "b": (1, 2, 5, 5),
}
- area = torch.ones((5, 5), device=fme.get_device())
- sigma_coordinates = SigmaCoordinates(ak=torch.arange(7), bk=torch.arange(7))
+ area = torch.ones((5, 5), device=DEVICE)
+ vertical_coordinate = HybridSigmaPressureCoordinate(
+ ak=torch.arange(7), bk=torch.arange(7)
+ )
stepper = config.get_stepper(
img_shape=shapes["a"][-2:],
- area=area,
- sigma_coordinates=sigma_coordinates,
+ gridded_operations=LatLonOperations(area),
+ vertical_coordinate=vertical_coordinate,
timestep=TIMESTEP,
)
- area = torch.ones((5, 5), device=fme.get_device())
- new_stepper = SingleModuleStepper.from_state(
- stepper.get_state(), area=area, sigma_coordinates=sigma_coordinates
- )
+ area = torch.ones((5, 5), device=DEVICE)
+ new_stepper = SingleModuleStepper.from_state(stepper.get_state())
data = get_data(["a", "b"], n_samples=5, n_time=2).data
- first_result = stepper.run_on_batch(
+ first_result = stepper.train_on_batch(
data=data,
optimization=NullOptimization(),
- n_forward_steps=1,
)
- second_result = new_stepper.run_on_batch(
+ second_result = new_stepper.train_on_batch(
data=data,
optimization=NullOptimization(),
- n_forward_steps=1,
)
assert torch.allclose(first_result.metrics["loss"], second_result.metrics["loss"])
assert torch.allclose(first_result.gen_data["a"], second_result.gen_data["a"])
assert torch.allclose(first_result.gen_data["b"], second_result.gen_data["b"])
- assert torch.allclose(
- first_result.gen_data_norm["a"], second_result.gen_data_norm["a"]
- )
- assert torch.allclose(
- first_result.gen_data_norm["b"], second_result.gen_data_norm["b"]
- )
- assert torch.allclose(
- first_result.target_data_norm["a"], second_result.target_data_norm["a"]
- )
- assert torch.allclose(
- first_result.target_data_norm["b"], second_result.target_data_norm["b"]
- )
assert torch.allclose(first_result.target_data["a"], second_result.target_data["a"])
assert torch.allclose(first_result.target_data["b"], second_result.target_data["b"])
@@ -314,24 +335,25 @@ def forward(self, x):
return zero + self._param
-def _setup_and_run_on_batch(
+def _setup_and_train_on_batch(
data: TensorDict,
in_names,
out_names,
ocean_config: Optional[OceanConfig],
- n_forward_steps,
optimization_config: Optional[OptimizationConfig],
):
- """Sets up the requisite classes to run run_on_batch."""
+ """Sets up the requisite classes to run train_on_batch."""
module = ReturnZerosModule(len(in_names), len(out_names))
if optimization_config is None:
optimization: Union[NullOptimization, Optimization] = NullOptimization()
else:
- optimization = optimization_config.build(module.parameters(), 2)
+ optimization = optimization_config.build(modules=[module], max_epochs=2)
- area = torch.ones((5, 5), device=fme.get_device())
- sigma_coordinates = SigmaCoordinates(ak=torch.arange(7), bk=torch.arange(7))
+ area = torch.ones((5, 5), device=DEVICE)
+ vertical_coordinate = HybridSigmaPressureCoordinate(
+ ak=torch.arange(7), bk=torch.arange(7)
+ )
config = SingleModuleStepperConfig(
builder=ModuleSelector(type="prebuilt", config={"module": module}),
in_names=in_names,
@@ -342,10 +364,10 @@ def _setup_and_run_on_batch(
),
ocean=ocean_config,
)
- stepper = config.get_stepper(area.shape, area, sigma_coordinates, TIMESTEP)
- return stepper.run_on_batch(
- data, optimization=optimization, n_forward_steps=n_forward_steps
+ stepper = config.get_stepper(
+ area.shape, LatLonOperations(area), vertical_coordinate, TIMESTEP
)
+ return stepper.train_on_batch(data, optimization=optimization)
@pytest.mark.parametrize(
@@ -358,7 +380,7 @@ def _setup_and_run_on_batch(
)
@pytest.mark.parametrize("n_forward_steps", [1, 2, 3], ids=lambda p: f"k={p}")
@pytest.mark.parametrize("is_train", [True, False], ids=["is_train", ""])
-def test_run_on_batch(n_forward_steps, is_input, is_output, is_train, is_prescribed):
+def test_train_on_batch(n_forward_steps, is_input, is_output, is_train, is_prescribed):
in_names, out_names = ["a"], ["a"]
if is_input:
in_names.append("b")
@@ -374,15 +396,59 @@ def test_run_on_batch(n_forward_steps, is_input, is_output, is_train, is_prescri
else:
ocean_config = None
- data, area_weights, sigma_coords = get_data(all_names, 3, n_forward_steps + 1)
+ data, _, _ = get_data(all_names, 3, n_forward_steps + 1)
if is_train:
optimization = OptimizationConfig()
else:
optimization = None
- _setup_and_run_on_batch(
- data, in_names, out_names, ocean_config, n_forward_steps, optimization
+ _setup_and_train_on_batch(data, in_names, out_names, ocean_config, optimization)
+
+
+@pytest.mark.parametrize("n_forward_steps", [1, 2, 3])
+def test_train_on_batch_one_step_aggregator(n_forward_steps):
+ in_names, out_names, all_names = ["a"], ["a"], ["a"]
+ data, _, _ = get_data(all_names, 3, n_forward_steps + 1)
+ stepper = _get_stepper(in_names, out_names, ocean_config=None, module_name="AddOne")
+
+ aggregator = OneStepAggregator(
+ gridded_operations=LatLonOperations(torch.ones((5, 5))),
+ )
+
+ stepped = stepper.train_on_batch(data, optimization=NullOptimization())
+ assert stepped.gen_data["a"].shape[1] == n_forward_steps + 1
+
+ aggregator.record_batch(stepped)
+ logs = aggregator.get_logs("one_step")
+
+ gen = data.data["a"].select(dim=1, index=0) + 1
+ tar = data.data["a"].select(dim=1, index=1)
+
+ bias = torch.mean(gen - tar)
+ assert np.isclose(bias.item(), logs["one_step/mean/weighted_bias/a"])
+
+ residual_gen = torch.ones((5, 5))
+ residual_tar = tar[0] - data.data["a"].select(dim=1, index=0)[0]
+ residual_imgs = [[residual_gen.cpu().numpy()], [residual_tar.cpu().numpy()]]
+ residual_plot = plot_paneled_data(residual_imgs, diverging=True)
+ assert np.allclose(
+ residual_plot.to_data_array(),
+ logs["one_step/snapshot/image-residual/a"].to_data_array(),
+ )
+
+ full_field_gen = gen.mean(dim=0)
+ full_field_tar = tar.mean(dim=0)
+ full_field_plot = plot_paneled_data(
+ [
+ [full_field_gen.cpu().numpy()],
+ [full_field_tar.cpu().numpy()],
+ ],
+ diverging=False,
+ )
+ assert np.allclose(
+ full_field_plot.to_data_array(),
+ logs["one_step/mean_map/image-full-field/a"].to_data_array(),
)
@@ -406,24 +472,28 @@ def forward(self, x):
(False, "advection_and_precipitation", True),
],
)
-def test_stepper_corrector(global_only: bool, terms_to_modify, force_positive: bool):
+@pytest.mark.parametrize("compute_derived_in_train_on_batch", [False, True])
+def test_stepper_corrector(
+ global_only: bool,
+ terms_to_modify,
+ force_positive: bool,
+ compute_derived_in_train_on_batch: bool,
+):
torch.random.manual_seed(0)
n_forward_steps = 5
device = get_device()
data = {
- "PRESsfc": 10.0 + torch.rand(size=(3, n_forward_steps + 1, 5, 5)).to(device),
+ "PRESsfc": 10.0 + torch.rand(size=(3, n_forward_steps + 1, 5, 5)),
"specific_total_water_0": -0.2
- + torch.rand(size=(3, n_forward_steps + 1, 5, 5)).to(device),
- "specific_total_water_1": torch.rand(size=(3, n_forward_steps + 1, 5, 5)).to(
- device
- ),
- "PRATEsfc": torch.rand(size=(3, n_forward_steps + 1, 5, 5)).to(device),
- "LHTFLsfc": torch.rand(size=(3, n_forward_steps + 1, 5, 5)).to(device),
+ + torch.rand(size=(3, n_forward_steps + 1, 5, 5)),
+ "specific_total_water_1": torch.rand(size=(3, n_forward_steps + 1, 5, 5)),
+ "PRATEsfc": torch.rand(size=(3, n_forward_steps + 1, 5, 5)),
+ "LHTFLsfc": torch.rand(size=(3, n_forward_steps + 1, 5, 5)),
"tendency_of_total_water_path_due_to_advection": torch.rand(
size=(3, n_forward_steps + 1, 5, 5)
- ).to(device),
+ ),
}
- sigma_coordinates = SigmaCoordinates(
+ vertical_coordinate = HybridSigmaPressureCoordinate(
ak=torch.asarray([3.0, 1.0, 0.0]), bk=torch.asarray([0.0, 0.6, 1.0])
).to(device)
area_weights = 1.0 + torch.rand(size=(5, 5)).to(device)
@@ -441,14 +511,12 @@ def test_stepper_corrector(global_only: bool, terms_to_modify, force_positive: b
)
mean_advection = metrics.weighted_mean(
- data["tendency_of_total_water_path_due_to_advection"],
+ data["tendency_of_total_water_path_due_to_advection"].to(device),
weights=area_weights,
dim=[-2, -1],
)
assert (mean_advection.abs() > 0.0).all()
- # use a randomly initialized Linear layer for the module
- # using PrebuiltBuilder
stepper_config = SingleModuleStepperConfig(
builder=ModuleSelector(
type="prebuilt",
@@ -466,22 +534,35 @@ def test_stepper_corrector(global_only: bool, terms_to_modify, force_positive: b
)
stepper = stepper_config.get_stepper(
img_shape=data["PRESsfc"].shape[2:],
- area=area_weights,
- sigma_coordinates=sigma_coordinates,
+ gridded_operations=LatLonOperations(area_weights),
+ vertical_coordinate=vertical_coordinate,
timestep=TIMESTEP,
)
+ time = xr.DataArray(
+ [
+ [
+ cftime.DatetimeProlepticGregorian(
+ 2000, 1, int(i * 6 // 24) + 1, i * 6 % 24
+ )
+ for i in range(n_forward_steps + 1)
+ ]
+ for _ in range(3)
+ ],
+ dims=["sample", "time"],
+ )
+ batch_data = BatchData.new_on_cpu(
+ data=data,
+ time=time,
+ ).to_device()
# run the stepper on the data
with torch.no_grad():
- stepped = stepper.run_on_batch(
- data=data,
+ stepped = stepper.train_on_batch(
+ data=batch_data,
optimization=NullOptimization(),
- n_forward_steps=n_forward_steps,
+ compute_derived_variables=compute_derived_in_train_on_batch,
)
-
- stepped = compute_stepped_derived_quantities(
- stepped, sigma_coordinates=sigma_coordinates, timestep=TIMESTEP
- )
-
+ if not compute_derived_in_train_on_batch:
+ stepped = stepped.compute_derived_variables()
# check that the budget residual is zero
budget_residual = stepped.gen_data["total_water_path_budget_residual"]
if global_only:
@@ -516,7 +597,7 @@ def test_stepper_corrector(global_only: bool, terms_to_modify, force_positive: b
dry_air = (
metrics.weighted_mean(
ClimateData(stepped.gen_data).surface_pressure_due_to_dry_air(
- sigma_coordinates
+ vertical_coordinate
),
weights=area_weights,
dim=[-2, -1],
@@ -530,7 +611,7 @@ def test_stepper_corrector(global_only: bool, terms_to_modify, force_positive: b
# check that positive forcing is enforced
if force_positive:
for name in force_positive_names:
- assert stepped.gen_data[name].min() >= 0.0
+ assert stepped.gen_data[name][:, 1:].min() >= 0.0
def _get_stepper(
@@ -569,7 +650,9 @@ def forward(self, x):
all_names = list(set(in_names + out_names))
area = torch.ones((5, 5))
- sigma_coordinates = SigmaCoordinates(ak=torch.arange(7), bk=torch.arange(7))
+ vertical_coordinate = HybridSigmaPressureCoordinate(
+ ak=torch.arange(7), bk=torch.arange(7)
+ )
config = SingleModuleStepperConfig(
builder=ModuleSelector(type="prebuilt", config=module_config),
in_names=in_names,
@@ -581,12 +664,14 @@ def forward(self, x):
ocean=ocean_config,
**kwargs,
)
- return config.get_stepper((5, 5), area, sigma_coordinates, TIMESTEP)
+ return config.get_stepper(
+ (5, 5), LatLonOperations(area), vertical_coordinate, TIMESTEP
+ )
def test_step():
stepper = _get_stepper(["a", "b"], ["a", "b"])
- input_data = {x: torch.rand(3, 5, 5) for x in ["a", "b"]}
+ input_data = {x: torch.rand(3, 5, 5).to(DEVICE) for x in ["a", "b"]}
output = stepper.step(input_data, {})
@@ -596,7 +681,7 @@ def test_step():
def test_step_with_diagnostic():
stepper = _get_stepper(["a"], ["a", "c"], module_name="RepeatChannel")
- input_data = {"a": torch.rand(3, 5, 5)}
+ input_data = {"a": torch.rand(3, 5, 5).to(DEVICE)}
output = stepper.step(input_data, {})
torch.testing.assert_close(output["a"], input_data["a"])
torch.testing.assert_close(output["c"], input_data["a"])
@@ -604,7 +689,7 @@ def test_step_with_diagnostic():
def test_step_with_forcing_and_diagnostic():
stepper = _get_stepper(["a", "b"], ["a", "c"])
- input_data = {x: torch.rand(3, 5, 5) for x in ["a", "b"]}
+ input_data = {x: torch.rand(3, 5, 5).to(DEVICE) for x in ["a", "b"]}
output = stepper.step(input_data, {})
torch.testing.assert_close(output["a"], input_data["a"] + 1)
assert "b" not in output
@@ -615,8 +700,8 @@ def test_step_with_prescribed_ocean():
stepper = _get_stepper(
["a", "b"], ["a", "b"], ocean_config=OceanConfig("a", "mask")
)
- input_data = {x: torch.rand(3, 5, 5) for x in ["a", "b", "mask"]}
- ocean_data = {x: torch.rand(3, 5, 5) for x in ["a", "mask"]}
+ input_data = {x: torch.rand(3, 5, 5).to(DEVICE) for x in ["a", "b"]}
+ ocean_data = {x: torch.rand(3, 5, 5).to(DEVICE) for x in ["a", "mask"]}
output = stepper.step(input_data, ocean_data)
expected_a_output = torch.where(
torch.round(ocean_data["mask"]).to(int) == 1,
@@ -628,51 +713,105 @@ def test_step_with_prescribed_ocean():
assert set(output) == {"a", "b"}
+def get_data_for_predict(
+ n_steps, forcing_names: List[str]
+) -> Tuple[PrognosticState, BatchData]:
+ n_samples = 3
+ input_data = BatchData.new_on_device(
+ data={"a": torch.rand(n_samples, 1, 5, 5).to(DEVICE)},
+ time=xr.DataArray(
+ np.zeros((n_samples, 1)),
+ dims=["sample", "time"],
+ ),
+ ).get_start(
+ prognostic_names=["a"],
+ n_ic_timesteps=1,
+ )
+ forcing_data = BatchData.new_on_device(
+ data={
+ name: torch.rand(3, n_steps + 1, 5, 5).to(DEVICE) for name in forcing_names
+ },
+ time=xr.DataArray(
+ np.zeros((n_samples, n_steps + 1)),
+ dims=["sample", "time"],
+ ),
+ )
+ return input_data, forcing_data
+
+
def test_predict():
- stepper = _get_stepper(["a", "b"], ["a", "b"])
+ stepper = _get_stepper(["a"], ["a"])
n_steps = 3
- input_data = {x: torch.rand(3, 5, 5) for x in ["a", "b"]}
- forcing_data = {}
- output = stepper.predict(input_data, forcing_data, n_steps)
- for variable in ["a", "b"]:
- assert output[variable].size(dim=1) == n_steps
- torch.testing.assert_close(
- output[variable][:, -1], input_data[variable] + n_steps
- )
+ input_data, forcing_data = get_data_for_predict(n_steps, forcing_names=[])
+ forcing_data.data = {}
+ output, new_input_data = stepper.predict(input_data, forcing_data)
+ xr.testing.assert_allclose(forcing_data.time[:, 1:], output.time)
+ variable = "a"
+ assert output.data[variable].size(dim=1) == n_steps
+ torch.testing.assert_close(
+ output.data[variable][:, -1],
+ input_data.as_batch_data().data[variable][:, 0] + n_steps,
+ )
+ assert isinstance(new_input_data, PrognosticState)
+ new_input_state = new_input_data.as_batch_data()
+ assert isinstance(new_input_state, BatchData)
+ torch.testing.assert_close(
+ new_input_state.data[variable][:, 0], output.data[variable][:, -1]
+ )
+ assert new_input_state.time.equals(output.time[:, -1:])
def test_predict_with_forcing():
stepper = _get_stepper(["a", "b"], ["a"], module_name="ChannelSum")
n_steps = 3
- input_data = {"a": torch.rand(3, 5, 5)}
- forcing_data = {"b": torch.rand(3, n_steps + 1, 5, 5)}
- output = stepper.predict(input_data, forcing_data, n_steps)
- assert "b" not in output
- assert output["a"].size(dim=1) == n_steps
+ input_data, forcing_data = get_data_for_predict(n_steps, forcing_names=["b"])
+ output, new_input_data = stepper.predict(input_data, forcing_data)
+ assert "b" not in output.data
+ assert output.data["a"].size(dim=1) == n_steps
+ xr.testing.assert_allclose(forcing_data.time[:, 1:], output.time)
torch.testing.assert_close(
- output["a"][:, 0], input_data["a"] + forcing_data["b"][:, 0]
- )
+ output.data["a"][:, 0],
+ input_data.as_batch_data().data["a"][:, 0] + forcing_data.data["b"][:, 0],
+ )
+ assert isinstance(new_input_data, PrognosticState)
+ new_input_state = new_input_data.as_batch_data()
+ assert isinstance(new_input_state, BatchData)
+ torch.testing.assert_close(new_input_state.data["a"][:, 0], output.data["a"][:, -1])
+ assert "b" not in new_input_state.data
for n in range(1, n_steps):
- expected_a_output = output["a"][:, n - 1] + forcing_data["b"][:, n]
- torch.testing.assert_close(output["a"][:, n], expected_a_output)
+ expected_a_output = output.data["a"][:, n - 1] + forcing_data.data["b"][:, n]
+ torch.testing.assert_close(output.data["a"][:, n], expected_a_output)
+ xr.testing.assert_equal(output.time, forcing_data.time[:, 1:])
+ assert new_input_state.time.equals(output.time[:, -1:])
def test_predict_with_ocean():
stepper = _get_stepper(["a"], ["a"], ocean_config=OceanConfig("a", "mask"))
n_steps = 3
- input_data = {"a": torch.rand(3, 5, 5)}
- forcing_data = {x: torch.rand(3, n_steps + 1, 5, 5) for x in ["a", "mask"]}
- output = stepper.predict(input_data, forcing_data, n_steps)
- assert "mask" not in output
- assert output["a"].size(dim=1) == n_steps
+ input_data, forcing_data = get_data_for_predict(
+ n_steps, forcing_names=["a", "mask"]
+ )
+ output, new_input_data = stepper.predict(input_data, forcing_data)
+ xr.testing.assert_allclose(forcing_data.time[:, 1:], output.time)
+ assert "mask" not in output.data
+ assert output.data["a"].size(dim=1) == n_steps
for n in range(n_steps):
- previous_a = input_data["a"] if n == 0 else output["a"][:, n - 1]
+ previous_a = (
+ input_data.as_batch_data().data["a"][:, 0]
+ if n == 0
+ else output.data["a"][:, n - 1]
+ )
expected_a_output = torch.where(
- torch.round(forcing_data["mask"][:, n + 1]).to(int) == 1,
- forcing_data["a"][:, n + 1],
+ torch.round(forcing_data.data["mask"][:, n + 1]).to(int) == 1,
+ forcing_data.data["a"][:, n + 1],
previous_a + 1,
)
- torch.testing.assert_close(output["a"][:, n], expected_a_output)
+ torch.testing.assert_close(output.data["a"][:, n], expected_a_output)
+ assert isinstance(new_input_data, PrognosticState)
+ new_input_state = new_input_data.as_batch_data()
+ assert isinstance(new_input_state, BatchData)
+ torch.testing.assert_close(new_input_state.data["a"][:, 0], output.data["a"][:, -1])
+ assert new_input_state.time.equals(output.time[:, -1:])
def test_next_step_forcing_names():
@@ -682,39 +821,46 @@ def test_next_step_forcing_names():
module_name="ChannelSum",
next_step_forcing_names=["c"],
)
- input_data = {x: torch.rand(1, 5, 5) for x in ["a"]}
- forcing_data = {x: torch.rand(1, 2, 5, 5) for x in ["b", "c"]}
- stepper.predict(input_data, forcing_data, 1)
+ input_data, forcing_data = get_data_for_predict(n_steps=1, forcing_names=["b", "c"])
+ stepper.predict(input_data, forcing_data)
torch.testing.assert_close(
- stepper.module.module.last_input[:, 1, :], forcing_data["b"][:, 0]
+ stepper.module.module.last_input[:, 1, :], forcing_data.data["b"][:, 0]
)
torch.testing.assert_close(
- stepper.module.module.last_input[:, 2, :], forcing_data["c"][:, 1]
+ stepper.module.module.last_input[:, 2, :], forcing_data.data["c"][:, 1]
)
def test_prepend_initial_condition():
nt = 3
- x = torch.rand(3, nt, 5).to(fme.get_device())
- x_normed = (x - x.mean()) / x.std()
- stepped = SteppedData(
+ x = torch.rand(3, nt, 5).to(DEVICE)
+
+ def normalize(x):
+ result = {k: (v - 1) / 2 for k, v in x.items()}
+ return result
+
+ stepped = TrainOutput(
gen_data={"a": x, "b": x + 1},
- gen_data_norm={"a": x_normed, "b": x_normed + 1},
- target_data={"a": x, "b": x + 1},
- target_data_norm={"a": x_normed, "b": x_normed + 1},
+ target_data={"a": x + 2, "b": x + 3},
+ time=xr.DataArray(np.zeros((3, nt)), dims=["sample", "time"]),
metrics={"loss": torch.tensor(0.0)},
+ normalize=normalize,
)
- ic = {
- "a": torch.rand(3, 5).to(fme.get_device()),
- "b": torch.rand(3, 5).to(fme.get_device()),
+ ic_data = {
+ "a": torch.rand(3, 1, 5).to(DEVICE),
+ "b": torch.rand(3, 1, 5).to(DEVICE),
}
- ic_normed = {k: (v - v.mean()) / v.std() for k, v in ic.items()}
- prepended = stepped.prepend_initial_condition(ic, ic_normed)
+ ic = BatchData.new_on_device(
+ data=ic_data,
+ time=xr.DataArray(np.zeros((3, 1)), dims=["sample", "time"]),
+ ).get_start(
+ prognostic_names=["a", "b"],
+ n_ic_timesteps=1,
+ )
+ prepended = stepped.prepend_initial_condition(ic)
for v in ["a", "b"]:
- assert torch.allclose(prepended.gen_data[v][:, 0], ic[v])
- assert torch.allclose(prepended.gen_data_norm[v][:, 0], ic_normed[v])
- assert torch.allclose(prepended.target_data[v][:, 0], ic[v])
- assert torch.allclose(prepended.target_data_norm[v][:, 0], ic_normed[v])
+ assert torch.allclose(prepended.gen_data[v][:, :1], ic_data[v])
+ assert torch.allclose(prepended.target_data[v][:, :1], ic_data[v])
def test__combine_normalizers():
@@ -752,62 +898,48 @@ def test_stepper_from_state_using_resnorm_has_correct_normalizer():
# stepper loaded from state should have the appropriately combined
# full field and residual values in its loss_normalizer
torch.manual_seed(0)
- full_field_normalization = {
- "means": {"a": 0.0, "b": 0.0, "diagnostic": 0.0},
- "stds": {"a": 1.0, "b": 1.0, "diagnostic": 1.0},
- }
+ full_field_means = {"a": 0.0, "b": 0.0, "diagnostic": 0.0}
+ full_field_stds = {"a": 1.0, "b": 1.0, "diagnostic": 1.0}
# residual scalings might have diagnostic variables but the stepper
# should detect which prognostic variables to use from the set
- residual_normalization = {
- "means": {"a": 1.0, "b": 1.0, "diagnostic": 1.0},
- "stds": {"a": 2.0, "b": 2.0, "diagnostic": 2.0},
- }
+ residual_means = {"a": 1.0, "b": 1.0, "diagnostic": 1.0}
+ residual_stds = {"a": 2.0, "b": 2.0, "diagnostic": 2.0}
config = SingleModuleStepperConfig(
builder=ModuleSelector(
type="SphericalFourierNeuralOperatorNet", config={"scale_factor": 1}
),
in_names=["a", "b"],
out_names=["a", "b", "diagnostic"],
- normalization=NormalizationConfig(**full_field_normalization),
- residual_normalization=NormalizationConfig(**residual_normalization),
+ normalization=NormalizationConfig(means=full_field_means, stds=full_field_stds),
+ residual_normalization=NormalizationConfig(
+ means=residual_means, stds=residual_stds
+ ),
)
shapes = {
"a": (1, 1, 5, 5),
"b": (1, 1, 5, 5),
"diagnostic": (1, 1, 5, 5),
}
- area = torch.ones((5, 5), device=fme.get_device())
- sigma_coordinates = SigmaCoordinates(ak=torch.arange(7), bk=torch.arange(7))
+ area = torch.ones((5, 5), device=DEVICE)
+ vertical_coordinate = HybridSigmaPressureCoordinate(
+ ak=torch.arange(7), bk=torch.arange(7)
+ )
orig_stepper = config.get_stepper(
img_shape=shapes["a"][-2:],
- area=area,
- sigma_coordinates=sigma_coordinates,
+ gridded_operations=LatLonOperations(area),
+ vertical_coordinate=vertical_coordinate,
timestep=TIMESTEP,
)
- stepper_from_state = SingleModuleStepper.from_state(
- orig_stepper.get_state(), area=area, sigma_coordinates=sigma_coordinates
- )
+ stepper_from_state = SingleModuleStepper.from_state(orig_stepper.get_state())
for stepper in [orig_stepper, stepper_from_state]:
- assert stepper.loss_normalizer.means == {"a": 1.0, "b": 1.0, "diagnostic": 0.0}
- assert stepper.loss_normalizer.stds == {"a": 2.0, "b": 2.0, "diagnostic": 1.0}
- assert stepper.normalizer.means == full_field_normalization["means"]
- assert stepper.normalizer.stds == full_field_normalization["stds"]
-
-
-def test_stepper_effective_loss_scaling():
- custom_loss_weights = {"b": 2.0}
- loss_norm_means = {"a": 0.0, "b": 0.0}
- loss_norm_stds = {"a": 4.0, "b": 0.5}
- stepper = _get_stepper(
- in_names=["a", "b"],
- out_names=["a", "b"],
- loss=WeightedMappingLossConfig(weights=custom_loss_weights),
- loss_normalization=NormalizationConfig(
- means=loss_norm_means, stds=loss_norm_stds
- ),
- )
- assert stepper.effective_loss_scaling == {
- "a": torch.tensor(4.0),
- "b": torch.tensor(0.25),
- }
+ assert stepper.loss_normalizer.means == {
+ **residual_means,
+ "diagnostic": full_field_means["diagnostic"],
+ }
+ assert stepper.loss_normalizer.stds == {
+ **residual_stds,
+ "diagnostic": full_field_stds["diagnostic"],
+ }
+ assert stepper.normalizer.means == full_field_means
+ assert stepper.normalizer.stds == full_field_stds
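The assertions above encode the combination rule: prognostic outputs (names present in both `in_names` and `out_names`) take their loss scalings from the residual normalization, while diagnostic-only outputs fall back to the full-field normalization. A minimal sketch of that rule under those assumptions (standalone; not the library's `_combine_normalizers`):

def combine_normalization(in_names, out_names, full_field, residual):
    # Prognostic variables appear in both the inputs and outputs.
    prognostic = set(in_names) & set(out_names)
    means = {
        name: residual["means"][name] if name in prognostic else full_field["means"][name]
        for name in out_names
    }
    stds = {
        name: residual["stds"][name] if name in prognostic else full_field["stds"][name]
        for name in out_names
    }
    return means, stds

means, stds = combine_normalization(
    in_names=["a", "b"],
    out_names=["a", "b", "diagnostic"],
    full_field={"means": {"a": 0.0, "b": 0.0, "diagnostic": 0.0},
                "stds": {"a": 1.0, "b": 1.0, "diagnostic": 1.0}},
    residual={"means": {"a": 1.0, "b": 1.0, "diagnostic": 1.0},
              "stds": {"a": 2.0, "b": 2.0, "diagnostic": 2.0}},
)
assert means == {"a": 1.0, "b": 1.0, "diagnostic": 0.0}
assert stds == {"a": 2.0, "b": 2.0, "diagnostic": 1.0}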
diff --git a/fme/fme/ace/test_train.py b/fme/fme/ace/test_train.py
index 6b2327d..e62c0ee 100755
--- a/fme/fme/ace/test_train.py
+++ b/fme/fme/ace/test_train.py
@@ -1,6 +1,9 @@
+import copy
+import dataclasses
import pathlib
import subprocess
import tempfile
+import textwrap
import unittest.mock
import numpy as np
@@ -10,17 +13,34 @@
import yaml
from fme.ace.inference.evaluator import main as inference_evaluator_main
-from fme.ace.train.train import _restore_checkpoint, count_parameters
-from fme.ace.train.train import main as train_main
-from fme.ace.train.train_config import epoch_checkpoint_enabled
-from fme.core.data_loading.config import Slice
-from fme.core.testing import (
+from fme.ace.registry.test_hpx import (
+ conv_next_block_config,
+ decoder_config,
+ down_sampling_block_config,
+ encoder_config,
+ output_layer_config,
+ recurrent_block_config,
+ up_sampling_block_config,
+)
+from fme.ace.testing import (
DimSizes,
MonthlyReferenceData,
- save_2d_netcdf,
+ save_nd_netcdf,
save_scalar_netcdf,
)
+from fme.ace.train.train import main as train_main
+from fme.core.coordinates import (
+ HEALPixCoordinates,
+ HorizontalCoordinates,
+ LatLonCoordinates,
+)
+from fme.core.generics.trainer import (
+ _restore_checkpoint,
+ count_parameters,
+ epoch_checkpoint_enabled,
+)
from fme.core.testing.wandb import mock_wandb
+from fme.core.typing_ import Slice
REPOSITORY_PATH = pathlib.PurePath(__file__).parent.parent.parent.parent
JOB_SUBMISSION_SCRIPT_PATH = (
@@ -46,7 +66,67 @@ def _get_test_yaml_files(
max_epochs=1,
segment_epochs=1,
inference_forward_steps=2,
+ use_healpix=False,
):
+ input_time_size = 1
+ output_time_size = 1
+ if use_healpix:
+ in_channels = len(in_variable_names)
+ out_channels = len(out_variable_names)
+ prognostic_variables = min(
+ out_channels, in_channels
+        )  # number of variables shared between input and output
+ # in practice, we will need to compare variable names, since there
+ # are some input-only and some output-only channels.
+ # TODO: https://github.com/ai2cm/full-model/issues/1046
+ n_constants = 0
+        decoder_input_channels = 0  # previously 1, to indicate insolation; now no insolation input
+ input_time_size = 1 # TODO: change to 2 (issue #1177)
+ output_time_size = 1 # TODO: change to 4 (issue #1177)
+ spatial_dimensions_str = "healpix"
+
+ conv_next_block = conv_next_block_config(in_channels=in_channels)
+ down_sampling_block = down_sampling_block_config()
+ recurrent_block = recurrent_block_config()
+ encoder = encoder_config(
+ conv_next_block, down_sampling_block, n_channels=[16, 8, 4]
+ )
+ up_sampling_block = up_sampling_block_config()
+ output_layer = output_layer_config()
+ decoder = decoder_config(
+ conv_next_block,
+ up_sampling_block,
+ output_layer,
+ recurrent_block,
+ n_channels=[4, 8, 16],
+ )
+
+ # Need to manually indent these YAML strings
+ encoder_yaml = yaml.dump(
+ dataclasses.asdict(encoder), default_flow_style=False, indent=2
+ )
+ decoder_yaml = yaml.dump(
+ dataclasses.asdict(decoder), default_flow_style=False, indent=2
+ )
+ encoder_yaml_indented = textwrap.indent(encoder_yaml, " ")
+ decoder_yaml_indented = textwrap.indent(decoder_yaml, " ")
+ config_str = f"""
+ encoder:
+{encoder_yaml_indented}
+ decoder:
+{decoder_yaml_indented}
+ prognostic_variables: {prognostic_variables}
+ n_constants: {n_constants}
+ decoder_input_channels: {decoder_input_channels}
+ input_time_size: {input_time_size}
+ output_time_size: {output_time_size}
+ """
+ else:
+ config_str = """
+ num_layers: 2
+ embed_dim: 12"""
+ spatial_dimensions_str = "latlon"
+
new_stepper_config = f"""
in_names: {in_variable_names}
out_names: {out_variable_names}
@@ -57,12 +137,10 @@ def _get_test_yaml_files(
global_means_path: '{global_means_path}'
global_stds_path: '{global_stds_path}'
loss:
- global_mean_type: "LpLoss"
+ type: "MSE"
builder:
type: {nettype}
- config:
- num_layers: 2
- embed_dim: 12
+ config: {config_str}
ocean:
surface_temperature_name: {in_variable_names[0]}
ocean_fraction_name: {mask_name}
@@ -80,17 +158,18 @@ def _get_test_yaml_files(
train_loader:
dataset:
- data_path: '{train_data_path}'
+ spatial_dimensions: {spatial_dimensions_str}
batch_size: 2
- num_data_workers: 1
+ num_data_workers: 0
validation_loader:
dataset:
- data_path: '{valid_data_path}'
+ spatial_dimensions: {spatial_dimensions_str}
batch_size: 2
- num_data_workers: 1
+ num_data_workers: 0
optimization:
optimizer_type: "Adam"
lr: 0.001
- enable_automatic_mixed_precision: true
scheduler:
type: CosineAnnealingLR
kwargs:
@@ -102,7 +181,8 @@ def _get_test_yaml_files(
monthly_reference_data: {monthly_data_filename}
loader:
dataset:
- data_path: '{valid_data_path}'
+ data_path: '{valid_data_path}'
+ spatial_dimensions: {spatial_dimensions_str}
start_indices:
first: 0
n_initial_conditions: 2
@@ -139,6 +219,7 @@ def _get_test_yaml_files(
loader:
dataset:
data_path: '{valid_data_path}'
+ spatial_dimensions: {spatial_dimensions_str}
start_indices:
first: 0
n_initial_conditions: 2
@@ -156,6 +237,23 @@ def _get_test_yaml_files(
return f_train.name, f_inference.name
+def get_sizes(
+ spatial_dims: HorizontalCoordinates = LatLonCoordinates(
+ lon=torch.Tensor(np.arange(32)),
+ lat=torch.Tensor(np.arange(16)),
+ loaded_lat_name="grid_yt",
+ loaded_lon_name="grid_xt",
+ ),
+ n_time=3,
+ nz_interface=3,
+) -> DimSizes:
+ return DimSizes(
+ n_time=n_time,
+ horizontal=copy.deepcopy(spatial_dims.loaded_sizes),
+ nz_interface=nz_interface,
+ )
+
+
def _setup(
path,
nettype,
@@ -165,22 +263,31 @@ def _setup(
n_time=10,
timestep_days=5,
inference_forward_steps=2,
+ use_healpix=False,
):
if not path.exists():
path.mkdir()
seed = 0
np.random.seed(seed)
- in_variable_names = ["foo", "bar", "baz"]
- out_variable_names = ["foo", "bar"]
+ in_variable_names = [
+ "PRESsfc",
+ "specific_total_water_0",
+ "specific_total_water_1",
+ "baz",
+ ]
+ out_variable_names = ["PRESsfc", "specific_total_water_0", "specific_total_water_1"]
mask_name = "mask"
all_variable_names = list(set(in_variable_names + out_variable_names))
- dim_sizes = DimSizes(
- n_time=n_time,
- n_lat=16,
- n_lon=32,
- nz_interface=7,
- )
+ if use_healpix:
+ hpx_coords = HEALPixCoordinates(
+ face=torch.Tensor(np.arange(12)),
+ width=torch.Tensor(np.arange(16)),
+ height=torch.Tensor(np.arange(16)),
+ )
+ dim_sizes = get_sizes(spatial_dims=hpx_coords, n_time=n_time)
+ else:
+ dim_sizes = get_sizes(n_time=n_time)
data_dir = path / "data"
stats_dir = path / "stats"
@@ -188,7 +295,7 @@ def _setup(
data_dir.mkdir()
stats_dir.mkdir()
results_dir.mkdir()
- save_2d_netcdf(
+ save_nd_netcdf(
data_dir / "data.nc",
dim_sizes,
variable_names=all_variable_names + [mask_name],
@@ -203,15 +310,22 @@ def _setup(
variable_names=all_variable_names,
)
+ monthly_dim_sizes: DimSizes
+ if use_healpix:
+ hpx_coords = HEALPixCoordinates(
+ face=torch.Tensor(np.arange(12)),
+ width=torch.Tensor(np.arange(16)),
+ height=torch.Tensor(np.arange(16)),
+ )
+ monthly_dim_sizes = get_sizes(
+ spatial_dims=hpx_coords, n_time=10 * 12, nz_interface=1
+ )
+ else:
+ monthly_dim_sizes = get_sizes(n_time=10 * 12, nz_interface=1)
monthly_reference_data = MonthlyReferenceData(
path=data_dir,
names=out_variable_names,
- dim_sizes=DimSizes(
- n_time=10 * 12,
- n_lat=16,
- n_lon=32,
- nz_interface=1,
- ),
+ dim_sizes=monthly_dim_sizes,
n_ensemble=3,
)
@@ -230,61 +344,84 @@ def _setup(
max_epochs=max_epochs,
segment_epochs=segment_epochs,
inference_forward_steps=inference_forward_steps,
+ use_healpix=use_healpix,
)
return train_config_filename, inference_config_filename
@pytest.mark.parametrize(
- "nettype", ["SphericalFourierNeuralOperatorNet", "SFNO-v0.1.0"]
+ "nettype", ["SphericalFourierNeuralOperatorNet", "HEALPixRecUNet", "SFNO-v0.1.0"]
)
-def test_train_and_inference_inline(tmp_path, nettype):
+def test_train_and_inference_inline(tmp_path, nettype, very_fast_only: bool):
"""Make sure that training and inference run without errors
Args:
        tmp_path: pytest fixture for temporary workspace.
nettype: parameter indicating model architecture to use.
- debug: option for developers to allow use of pdb.
+ very_fast_only: parameter indicating whether to skip slow tests.
"""
+ if very_fast_only:
+ pytest.skip("Skipping non-fast tests")
# need multi-year to cover annual aggregator
train_config, inference_config = _setup(
tmp_path,
nettype,
+ log_to_wandb=True,
timestep_days=20,
n_time=int(366 * 3 / 20 + 1),
inference_forward_steps=int(366 * 3 / 20 / 2 - 1) * 2, # must be even
+ use_healpix=(nettype == "HEALPixRecUNet"),
)
# using pdb requires calling main functions directly
- train_main(
- yaml_config=train_config,
- )
+ with mock_wandb() as wandb:
+ train_main(
+ yaml_config=train_config,
+ )
+ wandb_logs = wandb.get_logs()
+
+ for log in wandb_logs:
+ # ensure inference time series is not logged
+ assert "inference/mean/forecast_step" not in log
+
# inference should not require stats files
(tmp_path / "stats" / "stats-mean.nc").unlink()
(tmp_path / "stats" / "stats-stddev.nc").unlink()
- inference_logs = inference_evaluator_main(yaml_config=inference_config)
- assert len(inference_logs) == 7 # 6 forward steps + 1 initial state
+
+ with mock_wandb() as wandb:
+ wandb.configure(log_to_wandb=True)
+ inference_evaluator_main(yaml_config=inference_config)
+ inference_logs = wandb.get_logs()
+
prediction_output_path = tmp_path / "output" / "autoregressive_predictions.nc"
- assert prediction_output_path.exists()
best_checkpoint_path = (
tmp_path / "output" / "training_checkpoints" / "best_ckpt.tar"
)
- assert best_checkpoint_path.exists()
best_inference_checkpoint_path = (
tmp_path / "output" / "training_checkpoints" / "best_inference_ckpt.tar"
)
+ assert best_checkpoint_path.exists()
assert best_inference_checkpoint_path.exists()
+ n_ic_timesteps = 1
+ n_forward_steps = 6
+ n_summary_steps = 1
+ assert len(inference_logs) == n_ic_timesteps + n_forward_steps + n_summary_steps
+ assert prediction_output_path.exists()
ds_prediction = xr.open_dataset(prediction_output_path)
- assert np.sum(np.isnan(ds_prediction["foo"].values)) == 0
- assert np.sum(np.isnan(ds_prediction["bar"].values)) == 0
+ assert np.sum(np.isnan(ds_prediction["PRESsfc"].values)) == 0
+ assert np.sum(np.isnan(ds_prediction["specific_total_water_0"].values)) == 0
+ assert np.sum(np.isnan(ds_prediction["specific_total_water_1"].values)) == 0
ds_target = xr.open_dataset(tmp_path / "output" / "autoregressive_target.nc")
assert np.sum(np.isnan(ds_target["baz"].values)) == 0
@pytest.mark.parametrize("nettype", ["SphericalFourierNeuralOperatorNet"])
-def test_resume(tmp_path, nettype):
+def test_resume(tmp_path, nettype, very_fast_only: bool):
"""Make sure the training is resumed from a checkpoint when restarted."""
+ if very_fast_only:
+ pytest.skip("Skipping non-fast tests")
mock = unittest.mock.MagicMock(side_effect=_restore_checkpoint)
- with unittest.mock.patch("fme.ace.train.train._restore_checkpoint", new=mock):
+ with unittest.mock.patch("fme.core.generics.trainer._restore_checkpoint", new=mock):
train_config, _ = _setup(
tmp_path, nettype, log_to_wandb=True, max_epochs=2, segment_epochs=1
)
@@ -292,28 +429,16 @@ def test_resume(tmp_path, nettype):
train_main(
yaml_config=train_config,
)
- assert (
- min([val["epoch"] for val in wandb.get_logs().values() if "epoch" in val])
- == 0
- )
- assert (
- max([val["epoch"] for val in wandb.get_logs().values() if "epoch" in val])
- == 0
- )
+ assert min([val["epoch"] for val in wandb.get_logs() if "epoch" in val]) == 0
+ assert max([val["epoch"] for val in wandb.get_logs() if "epoch" in val]) == 0
assert not mock.called
with mock_wandb() as wandb:
train_main(
yaml_config=train_config,
)
mock.assert_called()
- assert (
- min([val["epoch"] for val in wandb.get_logs().values() if "epoch" in val])
- == 1
- )
- assert (
- max([val["epoch"] for val in wandb.get_logs().values() if "epoch" in val])
- == 1
- )
+ assert min([val["epoch"] for val in wandb.get_logs() if "epoch" in val]) == 1
+ assert max([val["epoch"] for val in wandb.get_logs() if "epoch" in val]) == 1
@pytest.mark.parametrize("nettype", ["SphericalFourierNeuralOperatorNet"])
@@ -349,29 +474,33 @@ def _create_fine_tuning_config(path_to_train_config_yaml: str, path_to_checkpoin
with open(path_to_train_config_yaml, "r") as config_file:
config_data = yaml.safe_load(config_file)
config_data["stepper"] = {"checkpoint_path": path_to_checkpoint}
+ current_experiment_dir = config_data["experiment_dir"]
+ new_experiment_dir = pathlib.Path(current_experiment_dir) / "fine_tuning"
+ config_data["experiment_dir"] = str(new_experiment_dir)
with tempfile.NamedTemporaryFile(
mode="w", delete=False, suffix=".yaml"
) as new_config_file:
new_config_file.write(yaml.dump(config_data))
- return new_config_file.name
+ return new_config_file.name, new_experiment_dir
@pytest.mark.parametrize("nettype", ["SphericalFourierNeuralOperatorNet"])
-def test_fine_tuning(tmp_path, nettype):
+def test_fine_tuning(tmp_path, nettype, very_fast_only: bool):
"""Check that fine tuning config runs without errors."""
+ if very_fast_only:
+ pytest.skip("Skipping non-fast tests")
train_config, _ = _setup(tmp_path, nettype)
- train_main(
- yaml_config=train_config,
- )
+ train_main(yaml_config=train_config)
results_dir = tmp_path / "output"
ckpt = f"{results_dir}/training_checkpoints/best_ckpt.tar"
- fine_tuning_config = _create_fine_tuning_config(train_config, ckpt)
+ fine_tuning_config, new_results_dir = _create_fine_tuning_config(train_config, ckpt)
train_main(yaml_config=fine_tuning_config)
+ assert (new_results_dir / "training_checkpoints" / "ckpt.tar").exists()
def _create_copy_weights_after_batch_config(
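The HEALPix branch of `_get_test_yaml_files` above nests dataclass-dumped YAML under the `encoder:`/`decoder:` keys by indenting the dumped block with `textwrap.indent`. A minimal standalone sketch of that technique (the indent width here is chosen for illustration only):

import textwrap

import yaml

block = yaml.dump({"n_channels": [16, 8, 4]}, default_flow_style=False, indent=2)
nested = "encoder:\n" + textwrap.indent(block, "  ")
# The indented block parses as a mapping nested under the `encoder` key.
assert yaml.safe_load(nested) == {"encoder": {"n_channels": [16, 8, 4]}}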
diff --git a/fme/fme/ace/testing/__init__.py b/fme/fme/ace/testing/__init__.py
new file mode 100644
index 0000000..034a41a
--- /dev/null
+++ b/fme/fme/ace/testing/__init__.py
@@ -0,0 +1,9 @@
+from .fv3gfs_data import (
+ DimSize,
+ DimSizes,
+ FV3GFSData,
+ MonthlyReferenceData,
+ StatsData,
+ save_nd_netcdf,
+ save_scalar_netcdf,
+)
diff --git a/fme/fme/core/testing/fv3gfs_data.py b/fme/fme/ace/testing/fv3gfs_data.py
similarity index 71%
rename from fme/fme/core/testing/fv3gfs_data.py
rename to fme/fme/ace/testing/fv3gfs_data.py
index fede51b..3b20d42 100644
--- a/fme/fme/core/testing/fv3gfs_data.py
+++ b/fme/fme/ace/testing/fv3gfs_data.py
@@ -1,17 +1,18 @@
import dataclasses
import datetime
import pathlib
-from typing import List, Optional, Sequence, Tuple
+from typing import Any, List, Optional, Sequence, Tuple
import cftime
import numpy as np
import xarray as xr
-from fme.core.data_loading.config import XarrayDataConfig
-from fme.core.data_loading.inference import (
+from fme.ace.data_loading.inference import (
InferenceDataLoaderConfig,
InferenceInitialConditionIndices,
)
+from fme.core.coordinates import DimSize
+from fme.core.dataset.config import XarrayDataConfig
def _coord_value(
@@ -30,51 +31,59 @@ def _coord_value(
@dataclasses.dataclass
class DimSizes:
n_time: int
- n_lat: int
- n_lon: int
+ horizontal: List[DimSize]
nz_interface: int
@property
- def shape_2d(self) -> Tuple[int, int, int]:
- return (self.n_time, self.n_lat, self.n_lon)
+ def shape_nd(self) -> List[int]:
+ return [self.n_time] + [dim.size for dim in self.horizontal]
@property
- def dims_2d(self) -> Tuple[str, str, str]:
- return ("time", "grid_yt", "grid_xt")
+ def dims_nd(self) -> List[str]:
+ return ["time"] + [dim.name for dim in self.horizontal]
@property
def shape_vertical_interface(self) -> Tuple[int]:
return (self.nz_interface,)
+ @property
def items(self):
- return [
- ("time", self.n_time),
- ("grid_yt", self.n_lat),
- ("grid_xt", self.n_lon),
+ return [("time", self.n_time)] + [
+ (dim.name, dim.size) for dim in self.horizontal
]
+ def get_size(self, key: str) -> int:
+ for item in self.horizontal:
+ if item.name == key:
+ return item.size
+ raise KeyError(f"Dimension with name '{key}' not found.")
-def save_2d_netcdf(
+
+def save_nd_netcdf(
filename,
dim_sizes: DimSizes,
variable_names: List[str],
timestep_days: float,
time_varying_values: Optional[List[float]] = None,
+ save_vertical_coordinate: bool = True,
):
"""
- Save a 2D netcdf file with random data for the given variable names and
+    Save an N-dimensional netcdf file with random data for the given variable names and
dimensions.
Args:
filename: The filename to save the netcdf file to.
dim_sizes: The dimensions of the data.
variable_names: The names of the variables to save.
+ timestep_days: The number of days between each time step.
time_varying_values: If not None, the values to use for each time step.
+ save_vertical_coordinate: If True, save vertical coordinate variables.
"""
- ds = get_2d_dataset(
+ ds = get_nd_dataset(
dim_sizes=dim_sizes,
variable_names=variable_names,
timestep_days=timestep_days,
+ include_vertical_coordinate=save_vertical_coordinate,
)
if time_varying_values is not None:
for name in variable_names:
@@ -122,6 +131,8 @@ class FV3GFSData:
dim_sizes: DimSizes
timestep_days: float
time_varying_values: Optional[List[float]] = None
+ num_data_workers: int = 0
+ save_vertical_coordinate: bool = True
def __post_init__(self):
self.data_path.mkdir(parents=True, exist_ok=True)
@@ -133,12 +144,13 @@ def __post_init__(self):
f"Number of time-varying values ({len(self.time_varying_values)}) "
f"must match number of time steps ({self.dim_sizes.n_time})"
)
- save_2d_netcdf(
- self._data_filename,
+ save_nd_netcdf(
+ self.data_filename,
dim_sizes=self.dim_sizes,
variable_names=self.names,
timestep_days=self.timestep_days,
time_varying_values=self.time_varying_values,
+ save_vertical_coordinate=self.save_vertical_coordinate,
)
@property
@@ -147,7 +159,7 @@ def data_path(self):
return self.path / "data"
@property
- def _data_filename(self):
+ def data_filename(self):
return self.data_path / "data.nc"
@property
@@ -159,7 +171,7 @@ def inference_data_loader_config(self) -> InferenceDataLoaderConfig:
start_indices=InferenceInitialConditionIndices(
first=0, n_initial_conditions=1, interval=1
),
- num_data_workers=2,
+ num_data_workers=self.num_data_workers,
)
@@ -172,7 +184,7 @@ class MonthlyReferenceData:
def __post_init__(self):
self.data_path.mkdir(parents=True, exist_ok=True)
- ds = get_2d_dataset(
+ ds = get_nd_dataset(
dim_sizes=self.dim_sizes,
variable_names=self.names,
)
@@ -207,33 +219,43 @@ def data_filename(self):
return self.data_path / "monthly.nc"
-def get_2d_dataset(
+def get_nd_overrides(dim_sizes: DimSizes, timestep_days: float) -> dict[str, Any]:
+ coords_override: dict[str, Any]
+ time = [
+ cftime.DatetimeProlepticGregorian(2000, 1, 1)
+ + i * timestep_days * datetime.timedelta(days=1)
+ for i in range(dim_sizes.n_time)
+ ]
+ coords_override = {"time": time}
+ if "grid_yt" in dim_sizes.dims_nd:
+ n_lat = dim_sizes.get_size("grid_yt") # n_lat
+ n_lon = dim_sizes.get_size("grid_xt") # n_lon
+ grid_yt = np.linspace(-89.5, 89.5, n_lat)
+ grid_xt_start = 360.0 / n_lon / 2
+ grid_xt = np.linspace(grid_xt_start, 360.0 - grid_xt_start, n_lon)
+ coords_override["grid_yt"] = grid_yt
+ coords_override["grid_xt"] = grid_xt
+ return coords_override
+
+
+def get_nd_dataset(
dim_sizes: DimSizes,
variable_names: Sequence[str],
timestep_days: float = 1.0,
+ include_vertical_coordinate: bool = True,
):
"""
- Gets a dataset of [time, lat, lon] data.
+    Gets a dataset of [time, <horizontal dims>] data.
"""
data_vars = {}
for name in variable_names:
- data = np.random.randn(*dim_sizes.shape_2d).astype(np.float32)
+ data = np.random.randn(*dim_sizes.shape_nd).astype(np.float32)
data_vars[name] = xr.DataArray(
data,
- dims=["time", "grid_yt", "grid_xt"],
+ dims=dim_sizes.dims_nd,
attrs={"units": "m", "long_name": name},
)
-
- grid_yt = np.linspace(-89.5, 89.5, dim_sizes.n_lat)
- grid_xt_start = 360.0 / dim_sizes.n_lon / 2
- grid_xt = np.linspace(grid_xt_start, 360.0 - grid_xt_start, dim_sizes.n_lon)
- time = [
- cftime.DatetimeProlepticGregorian(2000, 1, 1)
- + i * timestep_days * datetime.timedelta(days=1)
- for i in range(dim_sizes.n_time)
- ]
-
- coords_override = {"grid_yt": grid_yt, "grid_xt": grid_xt, "time": time}
+ coords_override = get_nd_overrides(dim_sizes=dim_sizes, timestep_days=timestep_days)
coords = {
dim_name: (
@@ -244,12 +266,13 @@ def get_2d_dataset(
if dim_name not in coords_override
else coords_override[dim_name]
)
- for dim_name, size in dim_sizes.items()
+ for dim_name, size in dim_sizes.items
}
- for i in range(dim_sizes.nz_interface):
- data_vars[f"ak_{i}"] = np.float64(i)
- data_vars[f"bk_{i}"] = np.float64(i + 1)
+ if include_vertical_coordinate:
+ for i in range(dim_sizes.nz_interface):
+ data_vars[f"ak_{i}"] = np.float64(i)
+ data_vars[f"bk_{i}"] = np.float64(i + 1)
ds = xr.Dataset(data_vars=data_vars, coords=coords)
return ds
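The `DimSizes` refactor above replaces the fixed `n_lat`/`n_lon` attributes with a list of named horizontal dimensions, so the same helpers can describe lat-lon and HEALPix grids. A standalone sketch of the `shape_nd` logic, using a local stand-in for `DimSize` (whose definition lives in `fme.core.coordinates` and is not shown in this diff); the sizes match the test setup (16x32 lat-lon, 12x16x16 HEALPix):

import dataclasses
from typing import List

@dataclasses.dataclass
class _DimSize:  # hypothetical stand-in for fme.core.coordinates.DimSize
    name: str
    size: int

def shape_nd(n_time: int, horizontal: List[_DimSize]) -> List[int]:
    # Mirrors DimSizes.shape_nd: time followed by each horizontal dim size.
    return [n_time] + [dim.size for dim in horizontal]

latlon = [_DimSize("grid_yt", 16), _DimSize("grid_xt", 32)]
healpix = [_DimSize("face", 12), _DimSize("height", 16), _DimSize("width", 16)]
assert shape_nd(3, latlon) == [3, 16, 32]
assert shape_nd(3, healpix) == [3, 12, 16, 16]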
diff --git a/fme/fme/ace/train/__init__.py b/fme/fme/ace/train/__init__.py
index 35d2319..d95e567 100644
--- a/fme/fme/ace/train/__init__.py
+++ b/fme/fme/ace/train/__init__.py
@@ -1 +1,2 @@
-from fme.ace.train.train import Trainer, _restore_checkpoint, count_parameters
+from fme.ace.train.train import Trainer
+from fme.core.generics.trainer import count_parameters
diff --git a/fme/fme/ace/train/train.py b/fme/fme/ace/train/train.py
index aecf943..15355ec 100644
--- a/fme/fme/ace/train/train.py
+++ b/fme/fme/ace/train/train.py
@@ -48,484 +48,171 @@
# Karthik Kashinath - NVIDIA Corporation
# Animashree Anandkumar - California Institute of Technology, NVIDIA Corporation
-import contextlib
import dataclasses
-import gc
import logging
import os
-import time
-import uuid
-from typing import Optional
+from datetime import timedelta
+from typing import Callable, Dict, Mapping, Optional, Sequence
import dacite
import dask
import torch
+import xarray as xr
import yaml
import fme
import fme.core.logging_utils as logging_utils
-from fme.ace.inference import run_inference_evaluator
-from fme.ace.inference.derived_variables import compute_stepped_derived_quantities
-from fme.ace.train.train_config import TrainConfig
-from fme.core.aggregator import (
+from fme.ace.aggregator import OneStepAggregator, TrainAggregator
+from fme.ace.aggregator.inference.main import (
+ InferenceEvaluatorAggregator,
InferenceEvaluatorAggregatorConfig,
- OneStepAggregator,
- TrainAggregator,
)
-from fme.core.data_loading.getters import get_data_loader, get_inference_data
-from fme.core.data_loading.utils import BatchData
+from fme.ace.data_loading.batch_data import PairedData, PrognosticState
+from fme.ace.stepper import TrainOutput
+from fme.ace.train.train_config import TrainBuilders, TrainConfig
+from fme.core.coordinates import HorizontalCoordinates, HybridSigmaPressureCoordinate
+from fme.core.dataset.data_typing import VariableMetadata
+from fme.core.dicts import to_flat_dict
from fme.core.distributed import Distributed
-from fme.core.ema import EMATracker
-from fme.core.optimization import NullOptimization
-from fme.core.stepper import SingleModuleStepper
-from fme.core.wandb import WandB
+from fme.core.generics.trainer import AggregatorBuilderABC, TrainConfigProtocol, Trainer
+from fme.core.gridded_ops import GriddedOperations
+from fme.core.typing_ import TensorDict, TensorMapping
# dask used on individual workers to load batches
dask.config.set(scheduler="synchronous")
-def count_parameters(modules: torch.nn.ModuleList) -> int:
- parameters = 0
- for module in modules:
- for parameter in module.parameters():
- if parameter.requires_grad:
- parameters += parameter.numel()
- return parameters
+def build_trainer(builder: TrainBuilders, config: TrainConfig) -> "Trainer":
+    # Note for developers: this helper is provided for convenience; a custom
+    # trainer can be built without it.
+ train_data = builder.get_train_data()
+ validation_data = builder.get_validation_data()
+ inference_data = builder.get_evaluation_inference_data()
-
-class Trainer:
- def __init__(self, config: TrainConfig):
- logging.info(f"Current device is {fme.get_device()}")
- self.dist = Distributed.get_instance()
- if self.dist.is_root():
- if not os.path.isdir(config.experiment_dir):
- os.makedirs(config.experiment_dir)
- if not os.path.isdir(config.checkpoint_dir):
- os.makedirs(config.checkpoint_dir)
- self.config = config
-
- data_requirements = config.stepper.get_data_requirements(
- n_forward_steps=self.config.n_forward_steps
- )
- logging.info("rank %d, begin data loader init" % self.dist.rank)
- self.train_data = get_data_loader(
- config.train_loader,
- requirements=data_requirements,
- train=True,
- )
- self.valid_data = get_data_loader(
- config.validation_loader,
- requirements=data_requirements,
- train=False,
- )
- logging.info("rank %d, data loader initialized" % self.dist.rank)
- for gridded_data, name in zip(
- (self.train_data, self.valid_data), ("train", "valid")
- ):
- n_samples = len(gridded_data.loader.dataset)
- n_batches = len(gridded_data.loader)
- logging.info(f"{name} data: {n_samples} samples, {n_batches} batches")
- first_time = gridded_data.loader.dataset[0][1].values[0]
- last_time = gridded_data.loader.dataset[-1][1].values[0]
- logging.info(f"{name} data: first sample's initial time: {first_time}")
- logging.info(f"{name} data: last sample's initial time: {last_time}")
-
- self.num_batches_seen = 0
- self.startEpoch = 0
- self._model_epoch = self.startEpoch
- self.num_batches_seen = 0
- self._best_validation_loss = torch.inf
- self._best_inference_error = torch.inf
-
- for batch in self.train_data.loader:
- shapes = {k: v.shape for k, v in batch.data.items()}
- for value in shapes.values():
- img_shape = value[-2:]
- break
+ for batch in train_data.loader:
+ shapes = {k: v.shape for k, v in batch.data.items()}
+ for value in shapes.values():
+ img_shape = value[-2:]
break
- logging.info("Starting model initialization")
- self.stepper = config.stepper.get_stepper(
- img_shape=img_shape,
- area=self.train_data.area_weights,
- sigma_coordinates=self.train_data.sigma_coordinates,
- timestep=self.train_data.timestep,
- )
- self.optimization = config.optimization.build(
- self.stepper.module.parameters(), config.max_epochs
- )
- self._base_weights = self.config.stepper.get_base_weights()
- self._copy_after_batch = config.copy_weights_after_batch
- self._no_optimization = NullOptimization()
-
- if config.resuming:
- logging.info("Loading checkpoint %s" % config.latest_checkpoint_path)
- self.restore_checkpoint(
- config.latest_checkpoint_path, config.ema_checkpoint_path
- )
-
- wandb = WandB.get_instance()
- wandb.watch(self.stepper.modules)
-
- logging.info(
- (
- "Number of trainable model parameters: "
- f"{count_parameters(self.stepper.modules)}"
- )
- )
- inference_data_requirements = dataclasses.replace(data_requirements)
- inference_data_requirements.n_timesteps = config.inference.n_forward_steps + 1
-
- self._inference_data = get_inference_data(
- config.inference.loader,
- config.inference.forward_steps_in_memory,
- inference_data_requirements,
- )
-
- self._ema = self.config.ema.build(self.stepper.modules)
-
- def switch_off_grad(self, model):
- for param in model.parameters():
- param.requires_grad = False
-
- def train(self):
- logging.info("Starting Training Loop...")
-
- self._model_epoch = self.startEpoch
- inference_epochs = list(range(0, self.config.max_epochs))[
- self.config.inference.epochs.slice
- ]
- if self.config.segment_epochs is None:
- segment_max_epochs = self.config.max_epochs
- else:
- segment_max_epochs = min(
- self.startEpoch + self.config.segment_epochs, self.config.max_epochs
- )
- # "epoch" describes the loop, self._model_epoch describes model weights
- # needed so we can describe the loop even after weights are updated
- for epoch in range(self.startEpoch, segment_max_epochs):
- # garbage collect to avoid CUDA error in some contexts
- # https://github.com/pytorch/pytorch/issues/67978#issuecomment-1661986812 # noqa: E501
- gc.collect()
- logging.info(f"Epoch: {epoch+1}")
- if isinstance(self.train_data.sampler, torch.utils.data.DistributedSampler):
- self.train_data.sampler.set_epoch(epoch)
-
- start_time = time.time()
- logging.info(f"Starting training step on epoch {epoch + 1}")
- train_logs = self.train_one_epoch()
- train_end = time.time()
- logging.info(f"Starting validation step on epoch {epoch + 1}")
- valid_logs = self.validate_one_epoch()
- valid_end = time.time()
- if epoch in inference_epochs:
- logging.info(f"Starting inference step on epoch {epoch + 1}")
- inference_logs = self.inference_one_epoch()
- inference_end: Optional[float] = time.time()
- else:
- inference_logs = {}
- inference_end = None
-
- train_loss = train_logs["train/mean/loss"]
- valid_loss = valid_logs["val/mean/loss"]
- inference_error = inference_logs.get(
- "inference/time_mean_norm/rmse/channel_mean", None
- )
- # need to get the learning rate before stepping the scheduler
- lr = self.optimization.learning_rate
- self.optimization.step_scheduler(valid_loss)
-
- if self.dist.is_root():
- if self.config.save_checkpoint:
- logging.info(f"Saving checkpoints for epoch {epoch + 1}")
- self.save_all_checkpoints(valid_loss, inference_error)
-
- time_elapsed = time.time() - start_time
- logging.info(f"Time taken for epoch {epoch + 1} is {time_elapsed} sec")
- logging.info(f"Train loss: {train_loss}. Valid loss: {valid_loss}")
- if inference_error is not None:
- logging.info(f"Inference error: {inference_error}")
-
- logging.info("Logging to wandb")
- all_logs = {
- **train_logs,
- **valid_logs,
- **inference_logs,
- **{
- "lr": lr,
- "epoch": epoch,
- "epoch_train_seconds": train_end - start_time,
- "epoch_validation_seconds": valid_end - train_end,
- "epoch_total_seconds": time_elapsed,
- },
- }
- if inference_end is not None:
- all_logs["epoch_inference_seconds"] = inference_end - valid_end
- wandb = WandB.get_instance()
- wandb.log(all_logs, step=self.num_batches_seen)
- if segment_max_epochs == self.config.max_epochs:
- self.config.clean_wandb()
-
- def train_one_epoch(self):
- """Train for one epoch and return logs from TrainAggregator."""
- wandb = WandB.get_instance()
- aggregator = TrainAggregator()
- if self.num_batches_seen == 0:
- # Before training, log the loss on the first batch.
- with torch.no_grad():
- batch = next(iter(self.train_data.loader))
- stepped = self.stepper.run_on_batch(
- batch.data,
- optimization=self._no_optimization,
- n_forward_steps=self.config.n_forward_steps,
- )
-
- if self.config.log_train_every_n_batches > 0:
- with torch.no_grad():
- metrics = {
- f"batch_{name}": self.dist.reduce_mean(metric)
- for name, metric in sorted(stepped.metrics.items())
- }
- wandb.log(metrics, step=self.num_batches_seen)
- batch: BatchData
- current_time = time.time()
- for batch in self.train_data.loader:
- stepped = self.stepper.run_on_batch(
- batch.data,
- self.optimization,
- n_forward_steps=self.config.n_forward_steps,
- )
- aggregator.record_batch(stepped.metrics["loss"])
- if self._base_weights is not None:
- self._copy_after_batch.apply(
- weights=self._base_weights, modules=self.stepper.modules
- )
- self._ema(model=self.stepper.modules)
- self.num_batches_seen += 1
- if (
- self.config.log_train_every_n_batches > 0
- and self.num_batches_seen % self.config.log_train_every_n_batches == 0
- ):
- with torch.no_grad():
- metrics = {
- f"batch_{name}": self.dist.reduce_mean(metric)
- for name, metric in sorted(stepped.metrics.items())
- }
- duration = time.time() - current_time
- current_time = time.time()
- n_samples = (
- self.train_data.loader.batch_size
- * self.config.log_train_every_n_batches
- )
- samples_per_second = n_samples / duration
- metrics["training_samples_per_second"] = samples_per_second
- wandb.log(metrics, step=self.num_batches_seen)
- self._model_epoch += 1
-
- return aggregator.get_logs(label="train")
-
- @contextlib.contextmanager
- def _validation_context(self):
- """
- The context for running validation.
-
- In this context, the stepper uses the EMA model if
- `self.config.validate_using_ema` is True.
- """
- if self.config.validate_using_ema:
- with self._ema_context():
- yield
- else:
- yield
-
- @contextlib.contextmanager
- def _ema_context(self):
- """
- A context where the stepper uses the EMA model.
- """
- self._ema.store(parameters=self.stepper.modules.parameters())
- self._ema.copy_to(model=self.stepper.modules)
- try:
- yield
- finally:
- self._ema.restore(parameters=self.stepper.modules.parameters())
-
- def validate_one_epoch(self):
- aggregator = OneStepAggregator(
- self.train_data.area_weights.to(fme.get_device()),
- self.train_data.sigma_coordinates,
- self.train_data.metadata,
- loss_scaling=self.stepper.effective_loss_scaling,
- )
-
- with torch.no_grad(), self._validation_context():
- for batch in self.valid_data.loader:
- stepped = self.stepper.run_on_batch(
- batch.data,
- optimization=NullOptimization(),
- n_forward_steps=self.config.n_forward_steps,
- )
- # Prepend initial condition back to start of windows
- # as it's used to compute differenced quantities
- ic, normed_ic = self.stepper.get_initial_condition(batch.data)
- stepped = stepped.prepend_initial_condition(ic, normed_ic)
-
- stepped = compute_stepped_derived_quantities(
- stepped,
- self.valid_data.sigma_coordinates,
- self.valid_data.timestep,
- forcing_data=stepped.target_data,
- )
- aggregator.record_batch(
- loss=stepped.metrics["loss"],
- target_data=stepped.target_data,
- gen_data=stepped.gen_data,
- target_data_norm=stepped.target_data_norm,
- gen_data_norm=stepped.gen_data_norm,
- )
- return aggregator.get_logs(label="val")
+ break
+ logging.info("Starting model initialization")
+ stepper = builder.get_stepper(
+ img_shape=img_shape,
+ gridded_operations=train_data.gridded_operations,
+ vertical_coordinate=train_data.vertical_coordinate,
+ timestep=train_data.timestep,
+ )
+ end_of_batch_ops = builder.get_end_of_batch_ops(stepper.modules)
+
+ for batch in inference_data.loader:
+ initial_inference_times = batch.time.isel(time=0)
+ break
+ aggregator_builder = AggregatorBuilder(
+ inference_config=config.inference_aggregator,
+ gridded_operations=train_data.gridded_operations,
+ vertical_coordinate=train_data.vertical_coordinate,
+ horizontal_coordinates=train_data.horizontal_coordinates,
+ timestep=train_data.timestep,
+ initial_inference_time=initial_inference_times,
+ record_step_20=config.inference_n_forward_steps >= 20,
+ n_timesteps=config.inference_n_forward_steps + stepper.n_ic_timesteps,
+ variable_metadata=train_data.variable_metadata,
+ loss_scaling=stepper.effective_loss_scaling,
+ channel_mean_names=stepper.out_names,
+ normalize=stepper.normalizer.normalize,
+ )
+ do_gc_collect = fme.get_device() != torch.device("cpu")
+ trainer_config: TrainConfigProtocol = config # documenting trainer input type
+ return Trainer(
+ train_data=train_data,
+ validation_data=validation_data,
+ inference_data=inference_data,
+ stepper=stepper,
+ build_optimization=builder.get_optimization,
+ build_ema=builder.get_ema,
+ config=trainer_config,
+ aggregator_builder=aggregator_builder,
+ end_of_batch_callback=end_of_batch_ops,
+ do_gc_collect=do_gc_collect,
+ )
- def inference_one_epoch(self):
- record_step_20 = self.config.inference.n_forward_steps >= 20
- aggregator_config: InferenceEvaluatorAggregatorConfig = (
- self.config.inference.aggregator
- )
- for batch in self._inference_data.loader:
- initial_times = batch.times.isel(time=0)
- break
- aggregator = aggregator_config.build(
- area_weights=self.train_data.area_weights.to(fme.get_device()),
- sigma_coordinates=self.train_data.sigma_coordinates,
- timestep=self.train_data.timestep,
- initial_times=initial_times,
- record_step_20=record_step_20,
- n_timesteps=self.config.inference.n_forward_steps + 1,
- metadata=self.train_data.metadata,
- data_grid=self.train_data.grid,
- )
- with torch.no_grad(), self._validation_context():
- run_inference_evaluator(
- aggregator=aggregator,
- stepper=self.stepper,
- data=self._inference_data,
- )
- logs = aggregator.get_logs(label="inference")
- if "inference/mean/series" in logs:
- # Tables don't work well when reported every epoch, this is a quick
- # workaround to remove them. Could refactor to avoid returning
- # at all, but it's used when converting the logs to epoch-wise
- # wandb logs in standalone inference.
- logs.pop("inference/mean/series")
- if "inference/mean_norm/series" in logs:
- logs.pop("inference/mean_norm/series")
- return logs
- def save_checkpoint(self, checkpoint_path):
- # save to a temporary file in case we get pre-empted during save
- temporary_location = os.path.join(
- os.path.dirname(checkpoint_path), f".{uuid.uuid4()}.tmp"
+class AggregatorBuilder(
+ AggregatorBuilderABC[PrognosticState, TrainOutput, PairedData],
+):
+ def __init__(
+ self,
+ inference_config: InferenceEvaluatorAggregatorConfig,
+ gridded_operations: GriddedOperations,
+ vertical_coordinate: HybridSigmaPressureCoordinate,
+ horizontal_coordinates: HorizontalCoordinates,
+ timestep: timedelta,
+ initial_inference_time: xr.DataArray,
+ record_step_20: bool,
+ n_timesteps: int,
+ normalize: Callable[[TensorMapping], TensorDict],
+ variable_metadata: Optional[Mapping[str, VariableMetadata]] = None,
+ loss_scaling: Optional[Dict[str, torch.Tensor]] = None,
+ channel_mean_names: Optional[Sequence[str]] = None,
+ ):
+ self.inference_config = inference_config
+ self.gridded_operations = gridded_operations
+ self.vertical_coordinate = vertical_coordinate
+ self.horizontal_coordinates = horizontal_coordinates
+ self.timestep = timestep
+ self.initial_inference_time = initial_inference_time
+ self.record_step_20 = record_step_20
+ self.n_timesteps = n_timesteps
+ self.variable_metadata = variable_metadata
+ self.loss_scaling = loss_scaling
+ self.channel_mean_names = channel_mean_names
+ self.normalize = normalize
+
+ def get_train_aggregator(self) -> TrainAggregator:
+ return TrainAggregator()
+
+ def get_validation_aggregator(self) -> OneStepAggregator:
+ return OneStepAggregator(
+ gridded_operations=self.gridded_operations,
+ variable_metadata=self.variable_metadata,
+ loss_scaling=self.loss_scaling,
)
- try:
- torch.save(
- {
- "num_batches_seen": self.num_batches_seen,
- "epoch": self._model_epoch,
- "best_validation_loss": self._best_validation_loss,
- "best_inference_error": self._best_inference_error,
- "stepper": self.stepper.get_state(),
- "optimization": self.optimization.get_state(),
- "ema": self._ema.get_state(),
- },
- temporary_location,
- )
- os.replace(temporary_location, checkpoint_path)
- finally:
- if os.path.exists(temporary_location):
- os.remove(temporary_location)
- def restore_checkpoint(self, checkpoint_path, ema_checkpoint_path):
- _restore_checkpoint(self, checkpoint_path, ema_checkpoint_path)
-
- def save_all_checkpoints(self, valid_loss: float, inference_error: Optional[float]):
- logging.info(
- f"Saving latest checkpoint to {self.config.latest_checkpoint_path}"
+ def get_inference_aggregator(
+ self,
+ ) -> InferenceEvaluatorAggregator:
+ return self.inference_config.build(
+ vertical_coordinate=self.vertical_coordinate,
+ horizontal_coordinates=self.horizontal_coordinates,
+ timestep=self.timestep,
+ initial_time=self.initial_inference_time,
+ record_step_20=self.record_step_20,
+ n_timesteps=self.n_timesteps,
+ variable_metadata=self.variable_metadata,
+ channel_mean_names=self.channel_mean_names,
+ normalize=self.normalize,
)
- self.save_checkpoint(self.config.latest_checkpoint_path)
- if self.config.epoch_checkpoint_enabled(self._model_epoch):
- epoch_checkpoint_path = self.config.epoch_checkpoint_path(self._model_epoch)
- logging.info(f"Saving epoch checkpoint to {epoch_checkpoint_path}")
- self.save_checkpoint(epoch_checkpoint_path)
- if self.config.ema_epoch_checkpoint_enabled(self._model_epoch):
- ema_epoch_checkpoint_path = self.config.ema_epoch_checkpoint_path(
- self._model_epoch
- )
- logging.info(f"Saving EMA epoch checkpoint to {ema_epoch_checkpoint_path}")
- with self._ema_context():
- self.save_checkpoint(ema_epoch_checkpoint_path)
- if self.config.validate_using_ema:
- best_checkpoint_context = self._ema_context
- else:
- best_checkpoint_context = contextlib.nullcontext # type: ignore
- with best_checkpoint_context():
- if valid_loss <= self._best_validation_loss:
- logging.info(
- "Saving lowest validation loss checkpoint to "
- f"{self.config.best_checkpoint_path}"
- )
- self._best_validation_loss = valid_loss
- self.save_checkpoint(self.config.best_checkpoint_path)
- if inference_error is not None and (
- inference_error <= self._best_inference_error
- ):
- logging.info(
- f"Epoch inference error ({inference_error}) is lower than "
- f"previous best inference error ({self._best_inference_error})."
- )
- logging.info(
- "Saving lowest inference error checkpoint to "
- f"{self.config.best_inference_checkpoint_path}"
- )
- self._best_inference_error = inference_error
- self.save_checkpoint(self.config.best_inference_checkpoint_path)
- with self._ema_context():
- logging.info(
- f"Saving latest EMA checkpoint to {self.config.ema_checkpoint_path}"
- )
- self.save_checkpoint(self.config.ema_checkpoint_path)
-def _restore_checkpoint(trainer: Trainer, checkpoint_path, ema_checkpoint_path):
- # separated into a function only to make it easier to mock
- checkpoint = torch.load(checkpoint_path, map_location=fme.get_device())
- # restore checkpoint is used for finetuning as well as resuming.
- # If finetuning (i.e., not resuming), restore checkpoint
- # does not load optimizer state, instead uses config specified lr.
- trainer.stepper.load_state(checkpoint["stepper"])
- trainer.optimization.load_state(checkpoint["optimization"])
- trainer.num_batches_seen = checkpoint["num_batches_seen"]
- trainer.startEpoch = checkpoint["epoch"]
- trainer._best_validation_loss = checkpoint["best_validation_loss"]
- trainer._best_inference_error = checkpoint["best_inference_error"]
- ema_checkpoint = torch.load(ema_checkpoint_path, map_location=fme.get_device())
- ema_stepper: SingleModuleStepper = SingleModuleStepper.from_state(
- ema_checkpoint["stepper"],
- area=trainer.train_data.area_weights,
- sigma_coordinates=trainer.train_data.sigma_coordinates,
- )
- trainer._ema = EMATracker.from_state(checkpoint["ema"], ema_stepper.modules)
+def run_train_from_config(config: TrainConfig):
+ run_train(TrainBuilders(config), config)
-def run_train_from_config(config: TrainConfig):
+def run_train(builders: TrainBuilders, config: TrainConfig):
dist = Distributed.get_instance()
if fme.using_gpu():
torch.backends.cudnn.benchmark = True
if not os.path.isdir(config.experiment_dir):
os.makedirs(config.experiment_dir, exist_ok=True)
- config.configure_logging(log_filename="out.log")
+ config.logging.configure_logging(config.experiment_dir, log_filename="out.log")
env_vars = logging_utils.retrieve_env_vars()
logging_utils.log_versions()
beaker_url = logging_utils.log_beaker_url()
- config.configure_wandb(env_vars=env_vars, resume=True, notes=beaker_url)
- trainer = Trainer(config)
+ config_as_dict = to_flat_dict(dataclasses.asdict(config))
+ config.logging.configure_wandb(
+ config=config_as_dict, env_vars=env_vars, resume=True, notes=beaker_url
+ )
+ trainer = build_trainer(builders, config)
trainer.train()
logging.info("DONE ---- rank %d" % dist.rank)
diff --git a/fme/fme/ace/train/train_config.py b/fme/fme/ace/train/train_config.py
index 1781b84..1d3e990 100644
--- a/fme/fme/ace/train/train_config.py
+++ b/fme/fme/ace/train/train_config.py
@@ -1,24 +1,37 @@
import dataclasses
-import logging
+import datetime
import os
-from typing import Any, Dict, Optional, Union
-
-from fme.core.aggregator import InferenceEvaluatorAggregatorConfig
-from fme.core.data_loading.config import DataLoaderConfig, Slice
-from fme.core.data_loading.inference import InferenceDataLoaderConfig
-from fme.core.dicts import to_flat_dict
+from typing import List, Optional, Tuple, Union
+
+import torch
+
+from fme.ace.aggregator import InferenceEvaluatorAggregatorConfig
+from fme.ace.data_loading.batch_data import GriddedData, InferenceGriddedData
+from fme.ace.data_loading.config import DataLoaderConfig
+from fme.ace.data_loading.getters import get_data_loader, get_inference_data
+from fme.ace.data_loading.inference import InferenceDataLoaderConfig
+from fme.ace.requirements import PrognosticStateDataRequirements
+from fme.ace.stepper import (
+ ExistingStepperConfig,
+ SingleModuleStepper,
+ SingleModuleStepperConfig,
+)
+from fme.core.coordinates import HybridSigmaPressureCoordinate
+from fme.core.dataset.requirements import DataRequirements
from fme.core.distributed import Distributed
-from fme.core.ema import EMAConfig
+from fme.core.ema import EMAConfig, EMATracker
+from fme.core.generics.trainer import EndOfBatchCallback
+from fme.core.gridded_ops import GriddedOperations
from fme.core.logging_utils import LoggingConfig
-from fme.core.optimization import OptimizationConfig
-from fme.core.stepper import ExistingStepperConfig, SingleModuleStepperConfig
+from fme.core.optimization import Optimization, OptimizationConfig
+from fme.core.typing_ import Slice
from fme.core.weight_ops import CopyWeightsConfig
@dataclasses.dataclass
class InlineInferenceConfig:
"""
- Attributes:
+ Parameters:
loader: configuration for the data loader used during inference
n_forward_steps: number of forward steps to take
forward_steps_in_memory: number of forward steps to take before
@@ -34,7 +47,9 @@ class InlineInferenceConfig:
forward_steps_in_memory: int = 2
epochs: Slice = Slice(start=0, stop=None, step=1)
aggregator: InferenceEvaluatorAggregatorConfig = dataclasses.field(
- default_factory=lambda: InferenceEvaluatorAggregatorConfig()
+ default_factory=lambda: InferenceEvaluatorAggregatorConfig(
+ log_global_mean_time_series=False, log_global_mean_norm_time_series=False
+ )
)
def __post_init__(self):
@@ -46,6 +61,14 @@ def __post_init__(self):
f"{self.loader.start_indices.n_initial_conditions} and "
f"{dist.world_size}."
)
+ if (
+ self.aggregator.log_global_mean_time_series
+ or self.aggregator.log_global_mean_norm_time_series
+ ):
+            # Both log_global_mean_time_series and log_global_mean_norm_time_series
+            # must be False for inline inference, so disable them here.
+ self.aggregator.log_global_mean_time_series = False
+ self.aggregator.log_global_mean_norm_time_series = False
@dataclasses.dataclass
@@ -53,7 +76,7 @@ class TrainConfig:
"""
Configuration for training a model.
- Attributes:
+    Parameters:
train_loader: Configuration for the training data loader.
validation_loader: Configuration for the validation data loader.
stepper: Configuration for the stepper.
@@ -103,6 +126,14 @@ class TrainConfig:
log_train_every_n_batches: int = 100
segment_epochs: Optional[int] = None
+ @property
+ def inference_n_forward_steps(self) -> int:
+ return self.inference.n_forward_steps
+
+ @property
+ def inference_aggregator(self) -> InferenceEvaluatorAggregatorConfig:
+ return self.inference.aggregator
+
@property
def checkpoint_dir(self) -> str:
"""
@@ -110,63 +141,83 @@ def checkpoint_dir(self) -> str:
"""
return os.path.join(self.experiment_dir, "training_checkpoints")
- @property
- def latest_checkpoint_path(self) -> str:
- return os.path.join(self.checkpoint_dir, "ckpt.tar")
+ def clean_wandb(self, experiment_dir: str) -> None:
+ self.logging.clean_wandb(experiment_dir=experiment_dir)
- @property
- def best_checkpoint_path(self) -> str:
- return os.path.join(self.checkpoint_dir, "best_ckpt.tar")
+ def get_inference_epochs(self) -> List[int]:
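+        # Epochs on which inline inference runs, selected by the configured Slice.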
+ return list(range(0, self.max_epochs))[self.inference.epochs.slice]
- @property
- def best_inference_checkpoint_path(self) -> str:
- return os.path.join(self.checkpoint_dir, "best_inference_ckpt.tar")
- @property
- def ema_checkpoint_path(self) -> str:
- return os.path.join(self.checkpoint_dir, "ema_ckpt.tar")
+class TrainBuilders:
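+    """Builds the components needed for training (data loaders, stepper,
+    optimization, EMA, callbacks) from a TrainConfig."""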
+ def __init__(self, config: TrainConfig):
+ self.config = config
- def epoch_checkpoint_path(self, epoch: int) -> str:
- return os.path.join(self.checkpoint_dir, f"ckpt_{epoch:04d}.tar")
-
- def ema_epoch_checkpoint_path(self, epoch: int) -> str:
- return os.path.join(self.checkpoint_dir, f"ema_ckpt_{epoch:04d}.tar")
-
- def epoch_checkpoint_enabled(self, epoch: int) -> bool:
- return epoch_checkpoint_enabled(
- epoch, self.max_epochs, self.checkpoint_save_epochs
+ def _get_train_window_data_requirements(self) -> DataRequirements:
+ return self.config.stepper.get_evaluation_window_data_requirements(
+ self.config.n_forward_steps
)
- def ema_epoch_checkpoint_enabled(self, epoch: int) -> bool:
- return epoch_checkpoint_enabled(
- epoch, self.max_epochs, self.ema_checkpoint_save_epochs
+ def _get_evaluation_window_data_requirements(self) -> DataRequirements:
+ return self.config.stepper.get_evaluation_window_data_requirements(
+ self.config.inference.forward_steps_in_memory
)
- @property
- def resuming(self) -> bool:
- checkpoint_file_exists = os.path.isfile(self.latest_checkpoint_path)
- resuming = True if checkpoint_file_exists else False
- return resuming
-
- def configure_logging(self, log_filename: str):
- self.logging.configure_logging(self.experiment_dir, log_filename)
-
- def configure_wandb(self, env_vars: Optional[Dict[str, Any]] = None, **kwargs):
- config = to_flat_dict(dataclasses.asdict(self))
- self.logging.configure_wandb(config=config, env_vars=env_vars, **kwargs)
+ def _get_initial_condition_data_requirements(
+ self,
+ ) -> PrognosticStateDataRequirements:
+ return self.config.stepper.get_prognostic_state_data_requirements()
+
+ def get_train_data(self) -> GriddedData:
+ data_requirements = self._get_train_window_data_requirements()
+ return get_data_loader(
+ self.config.train_loader,
+ requirements=data_requirements,
+ train=True,
+ )
- def log(self):
- logging.info("------------------ Configuration ------------------")
- logging.info(str(self))
- logging.info("---------------------------------------------------")
+ def get_validation_data(self) -> GriddedData:
+ data_requirements = self._get_train_window_data_requirements()
+ return get_data_loader(
+ self.config.validation_loader,
+ requirements=data_requirements,
+ train=False,
+ )
- def clean_wandb(self):
- self.logging.clean_wandb(experiment_dir=self.experiment_dir)
+ def get_evaluation_inference_data(
+ self,
+ ) -> InferenceGriddedData:
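+        # Data loader for the inline inference evaluation run during training.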
+ return get_inference_data(
+ config=self.config.inference.loader,
+ total_forward_steps=self.config.inference_n_forward_steps,
+ window_requirements=self._get_evaluation_window_data_requirements(),
+ initial_condition=self._get_initial_condition_data_requirements(),
+ )
+ def get_optimization(self, modules: torch.nn.ModuleList) -> Optimization:
+ return self.config.optimization.build(modules, self.config.max_epochs)
+
+ def get_stepper(
+ self,
+ img_shape: Tuple[int, int],
+ gridded_operations: GriddedOperations,
+ vertical_coordinate: HybridSigmaPressureCoordinate,
+ timestep: datetime.timedelta,
+ ) -> SingleModuleStepper:
+ return self.config.stepper.get_stepper(
+ img_shape=img_shape,
+ gridded_operations=gridded_operations,
+ vertical_coordinate=vertical_coordinate,
+ timestep=timestep,
+ )
-def epoch_checkpoint_enabled(
- epoch: int, max_epochs: int, save_epochs: Optional[Slice]
-) -> bool:
- if save_epochs is None:
- return False
- return epoch in range(max_epochs)[save_epochs.slice]
+ def get_ema(self, modules) -> EMATracker:
+ return self.config.ema.build(modules)
+
+ def get_end_of_batch_ops(
+ self, modules: List[torch.nn.Module]
+ ) -> EndOfBatchCallback:
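+        # If the stepper config provides base weights, return a callback that
+        # re-applies the copy-weights config to them after each batch;
+        # otherwise return a no-op.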
+ base_weights = self.config.stepper.get_base_weights()
+ if base_weights is not None:
+ copy_after_batch = self.config.copy_weights_after_batch
+ return lambda: copy_after_batch.apply(weights=base_weights, modules=modules)
+ return lambda: None
diff --git a/fme/fme/ace/validate_config.py b/fme/fme/ace/validate_config.py
index 0fab210..6fba79f 100644
--- a/fme/fme/ace/validate_config.py
+++ b/fme/fme/ace/validate_config.py
@@ -7,8 +7,8 @@
from fme.ace.inference.evaluator import InferenceEvaluatorConfig
from fme.ace.inference.inference import InferenceConfig
+from fme.ace.stepper import SingleModuleStepperConfig
from fme.ace.train.train_config import TrainConfig
-from fme.core.stepper import SingleModuleStepperConfig
CONFIG_CHOICES = ["train", "inference", "evaluator"]
diff --git a/fme/fme/core/__init__.py b/fme/fme/core/__init__.py
index 766547d..b7f028a 100644
--- a/fme/fme/core/__init__.py
+++ b/fme/fme/core/__init__.py
@@ -1,5 +1,6 @@
from .climate_data import ClimateData
from .device import get_device, using_gpu
+from .gridded_ops import GriddedOperations
from .metrics import (
root_mean_squared_error,
spherical_area_weights,
@@ -8,7 +9,6 @@
)
from .normalizer import StandardNormalizer, get_normalizer
from .packer import Packer
-from .stepper import SingleModuleStepper, SingleModuleStepperConfig
__all__ = [
"spherical_area_weights",
@@ -20,7 +20,6 @@
"StandardNormalizer",
"get_normalizer",
"Packer",
- "SingleModuleStepper",
- "SingleModuleStepperConfig",
"ClimateData",
+ "GriddedOperations",
]
diff --git a/fme/fme/core/aggregator/inference/main.py b/fme/fme/core/aggregator/inference/main.py
deleted file mode 100644
index badbd27..0000000
--- a/fme/fme/core/aggregator/inference/main.py
+++ /dev/null
@@ -1,552 +0,0 @@
-import dataclasses
-import datetime
-import warnings
-from typing import Dict, Iterable, List, Literal, Mapping, Optional, Protocol, Union
-
-import torch
-import xarray as xr
-
-from fme.core.data_loading.data_typing import SigmaCoordinates, VariableMetadata
-from fme.core.typing_ import TensorMapping
-from fme.core.wandb import Table, WandB
-
-from ..one_step.reduced import MeanAggregator as OneStepMeanAggregator
-from .annual import GlobalMeanAnnualAggregator
-from .enso import EnsoCoefficientEvaluatorAggregator
-from .histogram import HistogramAggregator
-from .reduced import MeanAggregator, SingleTargetMeanAggregator
-from .seasonal import SeasonalAggregator
-from .spectrum import PairedSphericalPowerSpectrumAggregator
-from .time_mean import TimeMeanAggregator, TimeMeanEvaluatorAggregator
-from .video import VideoAggregator
-from .zonal_mean import ZonalMeanAggregator
-
-wandb = WandB.get_instance()
-APPROXIMATELY_TWO_YEARS = datetime.timedelta(days=730)
-SLIGHTLY_LESS_THAN_FIVE_YEARS = datetime.timedelta(days=1800)
-
-
-class _Aggregator(Protocol):
- @torch.no_grad()
- def record_batch(
- self,
- data: TensorMapping,
- ):
- ...
-
- @torch.no_grad()
- def get_logs(self, label: str):
- ...
-
- @torch.no_grad()
- def get_dataset(self) -> xr.Dataset:
- ...
-
-
-class _EvaluatorAggregator(Protocol):
- @torch.no_grad()
- def record_batch(
- self,
- loss: float,
- target_data: TensorMapping,
- gen_data: TensorMapping,
- target_data_norm: TensorMapping,
- gen_data_norm: TensorMapping,
- i_time_start: int = 0,
- ):
- ...
-
- @torch.no_grad()
- def get_logs(self, label: str):
- ...
-
- @torch.no_grad()
- def get_dataset(self) -> xr.Dataset:
- ...
-
-
-class _TimeDependentAggregator(Protocol):
- @torch.no_grad()
- def record_batch(
- self,
- time: xr.DataArray,
- data: TensorMapping,
- ):
- ...
-
- @torch.no_grad()
- def get_logs(self, label: str):
- ...
-
-
-class _TimeDependentEvaluatorAggregator(Protocol):
- @torch.no_grad()
- def record_batch(
- self,
- time: xr.DataArray,
- target_data: TensorMapping,
- gen_data: TensorMapping,
- ):
- ...
-
- @torch.no_grad()
- def get_logs(self, label: str):
- ...
-
-
-@dataclasses.dataclass
-class InferenceEvaluatorAggregatorConfig:
- """
- Configuration for inference evaluator aggregator.
-
- Attributes:
- log_histograms: Whether to log histograms of the targets and predictions.
- log_video: Whether to log videos of the state evolution.
- log_extended_video: Whether to log wandb videos of the predictions with
- statistical metrics, only done if log_video is True.
- log_zonal_mean_images: Whether to log zonal-mean images (hovmollers) with a
- time dimension.
- log_seasonal_means: Whether to log seasonal mean metrics and images.
- monthly_reference_data: Path to monthly reference data to compare against.
- time_mean_reference_data: Path to reference time means to compare against.
- """
-
- log_histograms: bool = False
- log_video: bool = False
- log_extended_video: bool = False
- log_zonal_mean_images: bool = True
- log_seasonal_means: bool = False
- monthly_reference_data: Optional[str] = None
- time_mean_reference_data: Optional[str] = None
-
- def build(
- self,
- area_weights: torch.Tensor,
- sigma_coordinates: SigmaCoordinates,
- timestep: datetime.timedelta,
- n_timesteps: int,
- initial_times: xr.DataArray,
- record_step_20: bool = False,
- data_grid: Literal[
- "legendre-gauss", "equiangular", "healpix"
- ] = "legendre-gauss",
- metadata: Optional[Mapping[str, VariableMetadata]] = None,
- ) -> "InferenceEvaluatorAggregator":
- if self.monthly_reference_data is None:
- monthly_reference_data = None
- else:
- monthly_reference_data = xr.open_dataset(self.monthly_reference_data)
- if self.time_mean_reference_data is None:
- time_mean = None
- else:
- time_mean = xr.open_dataset(self.time_mean_reference_data)
-
- if n_timesteps > 2**15 and self.log_zonal_mean_images:
- # matplotlib raises an error if image size is too large, and we plot
- # one pixel per timestep in the zonal mean images.
- warnings.warn(
- "Disabling zonal mean images logging due to large number of timesteps"
- f" (n_timesteps={n_timesteps}). Set log_zonal_mean_images=False or "
- "decrease n_timesteps to below 2**15 to avoid this warning."
- )
- log_zonal_mean_images = False
- else:
- log_zonal_mean_images = self.log_zonal_mean_images
-
- return InferenceEvaluatorAggregator(
- area_weights=area_weights,
- sigma_coordinates=sigma_coordinates,
- timestep=timestep,
- n_timesteps=n_timesteps,
- initial_times=initial_times,
- log_histograms=self.log_histograms,
- log_video=self.log_video,
- enable_extended_videos=self.log_extended_video,
- log_zonal_mean_images=log_zonal_mean_images,
- log_seasonal_means=self.log_seasonal_means,
- monthly_reference_data=monthly_reference_data,
- time_mean_reference_data=time_mean,
- record_step_20=record_step_20,
- metadata=metadata,
- data_grid=data_grid,
- )
-
-
-class InferenceEvaluatorAggregator:
- """
- Aggregates statistics for inference comparing a generated and target series.
-
- To use, call `record_batch` on the results of each batch, then call
- `get_logs` to get a dictionary of statistics when you're done.
- """
-
- def __init__(
- self,
- area_weights: torch.Tensor,
- sigma_coordinates: SigmaCoordinates,
- timestep: datetime.timedelta,
- n_timesteps: int,
- initial_times: xr.DataArray,
- record_step_20: bool = False,
- log_video: bool = False,
- enable_extended_videos: bool = False,
- log_zonal_mean_images: bool = False,
- log_seasonal_means: bool = False,
- metadata: Optional[Mapping[str, VariableMetadata]] = None,
- monthly_reference_data: Optional[xr.Dataset] = None,
- log_histograms: bool = False,
- data_grid: Literal[
- "legendre-gauss", "equiangular", "healpix"
- ] = "legendre-gauss",
- time_mean_reference_data: Optional[xr.Dataset] = None,
- ):
- """
- Args:
- area_weights: Area weights for each grid cell.
- sigma_coordinates: Data sigma coordinates
- timestep: Timestep of the model.
- n_timesteps: Number of timesteps of inference that will be run.
- initial_times: Initial times for each sample.
- record_step_20: Whether to record the mean of the 20th steps.
- log_video: Whether to log videos of the state evolution.
- enable_extended_videos: Whether to log videos of statistical
- metrics of state evolution
- log_zonal_mean_images: Whether to log zonal-mean images (hovmollers) with a
- time dimension.
- log_seasonal_means: Whether to log seasonal means metrics and images.
- metadata: Mapping of variable names their metadata that will
- used in generating logged image captions.
- monthly_reference_data: Reference monthly data for computing target stats.
- log_histograms: Whether to aggregate histograms.
- data_grid: The grid type of the data, used for spherical power spectrum.
- time_mean_reference_data: Reference time means for computing bias stats.
- """
- self._aggregators: Dict[str, _EvaluatorAggregator] = {
- "mean": MeanAggregator(
- area_weights,
- target="denorm",
- n_timesteps=n_timesteps,
- metadata=metadata,
- ),
- "mean_norm": MeanAggregator(
- area_weights,
- target="norm",
- n_timesteps=n_timesteps,
- metadata=metadata,
- ),
- "time_mean": TimeMeanEvaluatorAggregator(
- area_weights,
- metadata=metadata,
- reference_means=time_mean_reference_data,
- ),
- "time_mean_norm": TimeMeanEvaluatorAggregator(
- area_weights,
- target="norm",
- metadata=metadata,
- ),
- }
- if len(area_weights.shape) == 2:
- self._aggregators[
- "spherical_power_spectrum"
- ] = PairedSphericalPowerSpectrumAggregator(
- area_weights.shape[-2],
- area_weights.shape[-1],
- data_grid,
- )
- else:
- warnings.warn(
- "Area weights are not 2D, spherical power spectrum will not be computed"
- )
- if record_step_20:
- self._aggregators["mean_step_20"] = OneStepMeanAggregator(
- area_weights, target_time=20
- )
- if log_video:
- self._aggregators["video"] = VideoAggregator(
- n_timesteps=n_timesteps,
- enable_extended_videos=enable_extended_videos,
- metadata=metadata,
- )
- if log_zonal_mean_images:
- self._aggregators["zonal_mean"] = ZonalMeanAggregator(
- n_timesteps=n_timesteps,
- metadata=metadata,
- )
- if log_histograms:
- self._aggregators["histogram"] = HistogramAggregator()
- self._time_dependent_aggregators: Dict[
- str, _TimeDependentEvaluatorAggregator
- ] = {}
- if log_seasonal_means:
- self._time_dependent_aggregators["seasonal"] = SeasonalAggregator(
- area_weights=area_weights,
- metadata=metadata,
- )
- if n_timesteps * timestep > APPROXIMATELY_TWO_YEARS:
- self._time_dependent_aggregators["annual"] = GlobalMeanAnnualAggregator(
- area_weights=area_weights,
- timestep=timestep,
- metadata=metadata,
- monthly_reference_data=monthly_reference_data,
- )
- if n_timesteps * timestep > SLIGHTLY_LESS_THAN_FIVE_YEARS:
- self._time_dependent_aggregators[
- "enso_coefficient"
- ] = EnsoCoefficientEvaluatorAggregator(
- initial_times,
- n_timesteps - 1,
- timestep,
- area_weights,
- metadata=metadata,
- )
-
- @torch.no_grad()
- def record_batch(
- self,
- loss: float,
- time: xr.DataArray,
- target_data: TensorMapping,
- gen_data: TensorMapping,
- target_data_norm: TensorMapping,
- gen_data_norm: TensorMapping,
- i_time_start: int = 0,
- ):
- if len(target_data) == 0:
- raise ValueError("No data in target_data")
- if len(gen_data) == 0:
- raise ValueError("No data in gen_data")
- target_data = {k: v for k, v in target_data.items() if k in gen_data}
- target_data_norm = {k: v for k, v in target_data_norm.items() if k in gen_data}
- for aggregator in self._aggregators.values():
- aggregator.record_batch(
- loss=loss,
- target_data=target_data,
- gen_data=gen_data,
- target_data_norm=target_data_norm,
- gen_data_norm=gen_data_norm,
- i_time_start=i_time_start,
- )
- for time_dependent_aggregator in self._time_dependent_aggregators.values():
- time_dependent_aggregator.record_batch(
- time=time,
- target_data=target_data,
- gen_data=gen_data,
- )
-
- @torch.no_grad()
- def get_logs(self, label: str):
- """
- Returns logs as can be reported to WandB.
-
- Args:
- label: Label to prepend to all log keys.
- """
- logs = {}
- for name, aggregator in self._aggregators.items():
- logs.update(aggregator.get_logs(label=name))
- for name, time_dependent_aggregator in self._time_dependent_aggregators.items():
- logs.update(time_dependent_aggregator.get_logs(label=name))
- logs = {f"{label}/{key}": val for key, val in logs.items()}
- return logs
-
- @torch.no_grad()
- def get_inference_logs(self, label: str) -> List[Dict[str, Union[float, int]]]:
- """
- Returns a list of logs to report to WandB.
-
- This is done because in inference, we use the wandb step
- as the time step, meaning we need to re-organize the logged data
- from tables into a list of dictionaries.
- """
- return to_inference_logs(self.get_logs(label=label))
-
- @torch.no_grad()
- def get_datasets(
- self, aggregator_whitelist: Optional[Iterable[str]] = None
- ) -> Dict[str, xr.Dataset]:
- """
- Args:
- aggregator_whitelist: aggregator names to include in the output. If
- None, return all the datasets associated with all aggregators.
- """
- datasets = (
- (name, agg.get_dataset()) for name, agg in self._aggregators.items()
- )
- if aggregator_whitelist is not None:
- filter_ = set(aggregator_whitelist)
- return {name: ds for name, ds in datasets if name in filter_}
-
- return {name: ds for name, ds in datasets}
-
-
-def to_inference_logs(
- log: Mapping[str, Union[Table, float, int]]
-) -> List[Dict[str, Union[float, int]]]:
- # we have a dictionary which contains WandB tables
- # which we will convert to a list of dictionaries, one for each
- # row in the tables. Any scalar values will be reported in the last
- # dictionary.
- n_rows = 0
- for val in log.values():
- if isinstance(val, Table):
- n_rows = max(n_rows, len(val.data))
- logs: List[Dict[str, Union[float, int]]] = []
- for i in range(n_rows):
- logs.append({})
- for key, val in log.items():
- if isinstance(val, Table):
- for i, row in enumerate(val.data):
- for j, col in enumerate(val.columns):
- key_without_table_name = key[: key.rfind("/")]
- logs[i][f"{key_without_table_name}/{col}"] = row[j]
- else:
- logs[-1][key] = val
- return logs
-
-
-def table_to_logs(table: Table) -> List[Dict[str, Union[float, int]]]:
- """
- Converts a WandB table into a list of dictionaries.
- """
- logs = []
- for row in table.data:
- logs.append({table.columns[i]: row[i] for i in range(len(row))})
- return logs
-
-
-@dataclasses.dataclass
-class InferenceAggregatorConfig:
- """
- Configuration for inference aggregator.
-
- Attributes:
- time_mean_reference_data: Path to reference time means to compare against.
- """
-
- time_mean_reference_data: Optional[str] = None
-
- def build(
- self,
- area_weights: torch.Tensor,
- sigma_coordinates: SigmaCoordinates,
- timestep: datetime.timedelta,
- n_timesteps: int,
- metadata: Optional[Mapping[str, VariableMetadata]] = None,
- ) -> "InferenceAggregator":
- if self.time_mean_reference_data is not None:
- time_means = xr.open_dataset(self.time_mean_reference_data)
- else:
- time_means = None
- return InferenceAggregator(
- area_weights=area_weights,
- sigma_coordinates=sigma_coordinates,
- timestep=timestep,
- n_timesteps=n_timesteps,
- metadata=metadata,
- time_mean_reference_data=time_means,
- )
-
-
-class InferenceAggregator:
- """
- Aggregates statistics on a single timeseries of data.
-
- To use, call `record_batch` on the results of each batch, then call
- `get_logs` to get a dictionary of statistics when you're done.
- """
-
- def __init__(
- self,
- area_weights: torch.Tensor,
- sigma_coordinates: SigmaCoordinates,
- timestep: datetime.timedelta,
- n_timesteps: int,
- metadata: Optional[Mapping[str, VariableMetadata]] = None,
- time_mean_reference_data: Optional[xr.Dataset] = None,
- ):
- """
- Args:
- area_weights: Area weights for each grid cell.
- sigma_coordinates: Data sigma coordinates
- timestep: Timestep of the model.
- metadata: Mapping of variable names their metadata that will
- used in generating logged image captions.
- time_mean_reference_data: Reference time means for computing bias stats.
- """
- self._aggregators: Dict[str, _Aggregator] = {
- "mean": SingleTargetMeanAggregator(
- area_weights,
- n_timesteps=n_timesteps,
- ),
- "time_mean": TimeMeanAggregator(
- area_weights,
- metadata=metadata,
- reference_means=time_mean_reference_data,
- ),
- }
- self._time_dependent_aggregators: Dict[str, _TimeDependentAggregator] = {}
-
- @torch.no_grad()
- def record_batch(
- self,
- time: xr.DataArray,
- data: TensorMapping,
- i_time_start: int,
- ):
- if len(data) == 0:
- raise ValueError("data is empty")
- for aggregator in self._aggregators.values():
- aggregator.record_batch(
- data=data,
- i_time_start=i_time_start,
- )
- for time_dependent_aggregator in self._time_dependent_aggregators.values():
- time_dependent_aggregator.record_batch(
- time=time,
- data=data,
- )
-
- @torch.no_grad()
- def get_logs(self, label: str):
- """
- Returns logs as can be reported to WandB.
-
- Args:
- label: Label to prepend to all log keys.
- """
- logs = {}
- for name, aggregator in self._aggregators.items():
- logs.update(aggregator.get_logs(label=name))
- for name, time_dependent_aggregator in self._time_dependent_aggregators.items():
- logs.update(time_dependent_aggregator.get_logs(label=name))
- logs = {f"{label}/{key}": val for key, val in logs.items()}
- return logs
-
- @torch.no_grad()
- def get_inference_logs(self, label: str) -> List[Dict[str, Union[float, int]]]:
- """
- Returns a list of logs to report to WandB.
-
- This is done because in inference, we use the wandb step
- as the time step, meaning we need to re-organize the logged data
- from tables into a list of dictionaries.
- """
- return to_inference_logs(self.get_logs(label=label))
-
- @torch.no_grad()
- def get_datasets(
- self, aggregator_whitelist: Optional[Iterable[str]] = None
- ) -> Dict[str, xr.Dataset]:
- """
- Args:
- aggregator_whitelist: aggregator names to include in the output. If
- None, return all the datasets associated with all aggregators.
- """
- datasets = (
- (name, agg.get_dataset()) for name, agg in self._aggregators.items()
- )
- if aggregator_whitelist is not None:
- filter_ = set(aggregator_whitelist)
- return {name: ds for name, ds in datasets if name in filter_}
-
- return {name: ds for name, ds in datasets}
diff --git a/fme/fme/core/aggregator/inference/test_evaluator.py b/fme/fme/core/aggregator/inference/test_evaluator.py
deleted file mode 100644
index c6edb0b..0000000
--- a/fme/fme/core/aggregator/inference/test_evaluator.py
+++ /dev/null
@@ -1,206 +0,0 @@
-import datetime
-
-import numpy as np
-import pytest
-import torch
-import xarray as xr
-
-import fme
-from fme.core.aggregator.inference import InferenceEvaluatorAggregator
-from fme.core.data_loading.data_typing import SigmaCoordinates
-from fme.core.device import get_device
-
-TIMESTEP = datetime.timedelta(hours=6)
-
-
-def get_zero_time(shape, dims):
- return xr.DataArray(np.zeros(shape, dtype="datetime64[ns]"), dims=dims)
-
-
-def test_logs_labels_exist():
- n_sample = 10
- n_time = 22
- nx = 2
- ny = 2
- nz = 3
- loss = 1.0
- area_weights = torch.ones(ny).to(fme.get_device())
- sigma_coordinates = SigmaCoordinates(torch.arange(nz + 1), torch.arange(nz + 1))
- initial_times = get_zero_time(shape=[n_sample, 0], dims=["sample", "time"])
-
- agg = InferenceEvaluatorAggregator(
- area_weights,
- sigma_coordinates,
- TIMESTEP,
- n_time,
- initial_times,
- record_step_20=True,
- log_video=True,
- log_zonal_mean_images=True,
- )
- target_data = {"a": torch.randn(n_sample, n_time, nx, ny, device=get_device())}
- gen_data = {"a": torch.randn(n_sample, n_time, nx, ny, device=get_device())}
- target_data_norm = {"a": torch.randn(n_sample, n_time, nx, ny, device=get_device())}
- gen_data_norm = {"a": torch.randn(n_sample, n_time, nx, ny, device=get_device())}
- time = get_zero_time(shape=[n_sample, n_time], dims=["sample", "time"])
- agg.record_batch(loss, time, target_data, gen_data, target_data_norm, gen_data_norm)
- logs = agg.get_logs(label="test")
- assert "test/mean/series" in logs
- assert "test/mean_norm/series" in logs
- assert "test/mean_step_20/weighted_rmse/a" in logs
- assert "test/mean_step_20/weighted_bias/a" in logs
- assert "test/mean_step_20/weighted_grad_mag_percent_diff/a" in logs
- table = logs["test/mean/series"]
- assert table.columns == [
- "forecast_step",
- "weighted_bias/a",
- "weighted_grad_mag_percent_diff/a",
- "weighted_mean_gen/a",
- "weighted_mean_target/a",
- "weighted_rmse/a",
- "weighted_std_gen/a",
- ]
- assert "test/time_mean/rmse/a" in logs
- assert "test/time_mean/bias/a" in logs
- assert "test/time_mean/bias_map/a" in logs
- assert "test/time_mean/gen_map/a" in logs
- assert "test/zonal_mean/error/a" in logs
- assert "test/zonal_mean/gen/a" in logs
- assert "test/video/a" in logs
-
-
-def test_inference_logs_labels_exist():
- n_sample = 10
- n_time = 22
- nx = 2
- ny = 2
- nz = 3
- loss = 1.0
- area_weights = torch.ones(ny).to(fme.get_device())
- sigma_coordinates = SigmaCoordinates(torch.arange(nz + 1), torch.arange(nz + 1))
- initial_times = (get_zero_time(shape=[n_sample, 0], dims=["sample", "time"]),)
- agg = InferenceEvaluatorAggregator(
- area_weights,
- sigma_coordinates,
- TIMESTEP,
- n_time,
- initial_times,
- record_step_20=True,
- log_video=True,
- )
- target_data = {"a": torch.randn(n_sample, n_time, nx, ny, device=get_device())}
- gen_data = {"a": torch.randn(n_sample, n_time, nx, ny, device=get_device())}
- target_data_norm = {"a": torch.randn(n_sample, n_time, nx, ny, device=get_device())}
- gen_data_norm = {"a": torch.randn(n_sample, n_time, nx, ny, device=get_device())}
- time = get_zero_time(shape=[n_sample, n_time], dims=["sample", "time"])
- agg.record_batch(loss, time, target_data, gen_data, target_data_norm, gen_data_norm)
- logs = agg.get_inference_logs(label="test")
- assert isinstance(logs, list)
- assert len(logs) == n_time
- assert "test/mean/weighted_bias/a" in logs[0]
- assert "test/mean/weighted_mean_gen/a" in logs[0]
- assert "test/mean/weighted_mean_target/a" in logs[0]
- assert "test/mean/weighted_grad_mag_percent_diff/a" in logs[0]
- assert "test/mean/weighted_rmse/a" in logs[0]
- assert "test/mean_norm/weighted_bias/a" in logs[0]
- assert "test/mean_norm/weighted_mean_gen/a" in logs[0]
- assert "test/mean_norm/weighted_mean_target/a" in logs[0]
- assert "test/mean_norm/weighted_rmse/a" in logs[0]
- # series/table data should be rolled out, not included as a table
- assert "test/mean/series" not in logs[0]
- assert "test/mean_norm/series" not in logs[0]
- assert "test/reduced/series" not in logs[0]
- assert "test/reduced_norm/series" not in logs[0]
-
-
-@pytest.mark.parametrize(
- "window_len, n_windows",
- [
- pytest.param(3, 1, id="single_window"),
- pytest.param(3, 2, id="two_windows"),
- ],
-)
-def test_i_time_start_gets_correct_time_longer_windows(window_len: int, n_windows: int):
- # while this directly tests the "mean" result, this is really a test that
- # the data from the correct timestep is piped into the aggregator.
- overlap = 1 # tested code assumes windows have one overlapping point
- area_weights = torch.ones(4).to(fme.get_device())
- nz = 3
- sigma_coordinates = SigmaCoordinates(torch.arange(nz + 1), torch.arange(nz + 1))
- initial_times = (get_zero_time(shape=[2, 0], dims=["sample", "time"]),)
- agg = InferenceEvaluatorAggregator(
- area_weights,
- sigma_coordinates,
- TIMESTEP,
- (window_len - overlap) * n_windows + 1,
- initial_times,
- )
- target_data = {"a": torch.zeros([2, window_len, 4, 4], device=get_device())}
- time = get_zero_time(shape=[2, window_len], dims=["sample", "time"])
- i_start = 0
- for i in range(n_windows):
- sample_data = {"a": torch.zeros([2, window_len, 4, 4], device=get_device())}
- for i in range(window_len):
- sample_data["a"][..., i, :, :] = float(i_start + i)
- agg.record_batch(
- 1.0,
- time=time,
- target_data=target_data,
- gen_data=sample_data,
- target_data_norm=target_data,
- gen_data_norm=sample_data,
- i_time_start=i_start,
- )
- i_start += window_len - overlap # subtract 1 for overlapping windows
- logs = agg.get_logs(label="metrics")
- table = logs["metrics/mean/series"]
- # get the weighted_bias column
- bias = table.get_column("weighted_bias/a")
- assert len(bias) == (window_len - overlap) * n_windows + overlap
- for i in range(len(bias)):
- np.testing.assert_allclose(bias[i], float(i), rtol=1e-5)
-
-
-@pytest.mark.parametrize(
- "window_len, n_windows, overlap",
- [
- pytest.param(3, 1, 0, id="single_window"),
- pytest.param(3, 2, 0, id="two_windows"),
- pytest.param(3, 2, 1, id="two_windows_overlap"),
- ],
-)
-def test_inference_logs_length(window_len: int, n_windows: int, overlap: int):
- """
- Test that the inference logs are the correct length when using one or more
- possibly-overlapping windows.
- """
- area_weights = torch.ones(4).to(fme.get_device())
- nz = 3
- sigma_coordinates = SigmaCoordinates(torch.arange(nz + 1), torch.arange(nz + 1))
- initial_times = (get_zero_time(shape=[2, 0], dims=["sample", "time"]),)
- agg = InferenceEvaluatorAggregator(
- area_weights,
- sigma_coordinates,
- TIMESTEP,
- (window_len - overlap) * n_windows + overlap,
- initial_times,
- )
- target_data = {"a": torch.zeros([2, window_len, 4, 4], device=get_device())}
- time = get_zero_time(shape=[2, window_len], dims=["sample", "time"])
- i_start = 0
- for i in range(n_windows):
- sample_data = {"a": torch.zeros([2, window_len, 4, 4], device=get_device())}
- for i in range(window_len):
- sample_data["a"][..., i, :, :] = float(i_start + i)
- agg.record_batch(
- 1.0,
- time=time,
- target_data=target_data,
- gen_data=sample_data,
- target_data_norm=target_data,
- gen_data_norm=sample_data,
- i_time_start=i_start,
- )
- i_start += window_len - overlap # subtract 1 for overlapping windows
- logs = agg.get_inference_logs(label="metrics")
- assert len(logs) == (window_len - overlap) * n_windows + overlap
diff --git a/fme/fme/core/aggregator/inference/test_inference.py b/fme/fme/core/aggregator/inference/test_inference.py
deleted file mode 100644
index d6781c0..0000000
--- a/fme/fme/core/aggregator/inference/test_inference.py
+++ /dev/null
@@ -1,101 +0,0 @@
-import datetime
-
-import numpy as np
-import torch
-import xarray as xr
-
-import fme
-from fme.core.aggregator.inference import InferenceAggregator
-from fme.core.data_loading.data_typing import SigmaCoordinates
-from fme.core.device import get_device
-
-TIMESTEP = datetime.timedelta(hours=6)
-
-
-def get_zero_time(shape, dims):
- return xr.DataArray(np.zeros(shape, dtype="datetime64[ns]"), dims=dims)
-
-
-def test_logs_labels_exist():
- n_sample = 10
- n_time = 22
- nx = 2
- ny = 2
- nz = 3
- area_weights = torch.ones(ny).to(fme.get_device())
- sigma_coordinates = SigmaCoordinates(torch.arange(nz + 1), torch.arange(nz + 1))
- agg = InferenceAggregator(
- area_weights,
- sigma_coordinates,
- TIMESTEP,
- n_timesteps=n_time,
- )
- gen_data = {"a": torch.randn(n_sample, n_time, nx, ny, device=get_device())}
- time = get_zero_time(shape=[n_sample, n_time], dims=["sample", "time"])
- agg.record_batch(time, data=gen_data, i_time_start=0)
- logs = agg.get_logs(label="test")
- assert "test/mean/series" in logs
- assert "test/time_mean/gen_map/a" in logs
- assert "test/time_mean/ref_bias_map/a" not in logs
- assert "test/time_mean/ref_bias/a" not in logs
- assert "test/time_mean/ref_rmse/a" not in logs
-
-
-def test_logs_labels_exist_with_reference_time_means():
- n_sample = 10
- n_time = 22
- nx = 2
- ny = 2
- nz = 3
- area_weights = torch.ones(ny).to(fme.get_device())
- sigma_coordinates = SigmaCoordinates(torch.arange(nz + 1), torch.arange(nz + 1))
- reference_time_means = xr.Dataset(
- {
- "a": xr.DataArray(
- np.random.randn(ny, nx),
- dims=["grid_yt", "grid_xt"],
- )
- }
- )
- agg = InferenceAggregator(
- area_weights,
- sigma_coordinates,
- TIMESTEP,
- n_timesteps=n_time,
- time_mean_reference_data=reference_time_means,
- )
- gen_data = {"a": torch.randn(n_sample, n_time, nx, ny, device=get_device())}
- time = get_zero_time(shape=[n_sample, n_time], dims=["sample", "time"])
- agg.record_batch(time, data=gen_data, i_time_start=0)
- logs = agg.get_logs(label="test")
- assert "test/mean/series" in logs
- assert "test/time_mean/gen_map/a" in logs
- assert "test/time_mean/ref_bias_map/a" in logs
- assert "test/time_mean/ref_bias/a" in logs
- assert "test/time_mean/ref_rmse/a" in logs
-
-
-def test_inference_logs_labels_exist():
- n_sample = 10
- n_time = 22
- nx = 2
- ny = 2
- nz = 3
- area_weights = torch.ones(ny).to(fme.get_device())
- sigma_coordinates = SigmaCoordinates(torch.arange(nz + 1), torch.arange(nz + 1))
- agg = InferenceAggregator(
- area_weights,
- sigma_coordinates,
- TIMESTEP,
- n_timesteps=n_time,
- )
- gen_data = {"a": torch.randn(n_sample, n_time, nx, ny, device=get_device())}
- time = get_zero_time(shape=[n_sample, n_time], dims=["sample", "time"])
- agg.record_batch(time, data=gen_data, i_time_start=0)
- logs = agg.get_inference_logs(label="test")
- assert isinstance(logs, list)
- assert len(logs) == n_time
- assert "test/mean/weighted_mean_gen/a" in logs[0]
- assert "test/mean/weighted_mean_gen/a" in logs[-1]
- # assert len(logs) == n_time use this assertion when timeseries data is generated
- assert "test/time_mean/gen_map/a" in logs[-1]
diff --git a/fme/fme/core/aggregator/one_step/derived.py b/fme/fme/core/aggregator/one_step/derived.py
deleted file mode 100644
index 92de5a7..0000000
--- a/fme/fme/core/aggregator/one_step/derived.py
+++ /dev/null
@@ -1,135 +0,0 @@
-"""Derived metrics take the global state as input and usually output a new
-variable, e.g. dry air mass."""
-
-import abc
-from dataclasses import dataclass
-from typing import Dict, List, Mapping, Optional, Tuple
-
-import torch
-
-from fme.core.climate_data import (
- CLIMATE_FIELD_NAME_PREFIXES,
- ClimateData,
- compute_dry_air_absolute_differences,
-)
-from fme.core.data_loading.data_typing import SigmaCoordinates
-from fme.core.device import get_device
-from fme.core.typing_ import TensorMapping
-
-
-@dataclass
-class _TargetGenPair:
- target: torch.Tensor
- gen: torch.Tensor
-
-
-class DerivedMetric(abc.ABC):
- """Derived metrics are computed from the global state and usually output a
- new variable, e.g. dry air tendencies."""
-
- @abc.abstractmethod
- def record(self, target: ClimateData, gen: ClimateData) -> None:
- ...
-
- @abc.abstractmethod
- def get(self) -> _TargetGenPair:
- """Returns the derived metric applied to the target and data generated
- by the model."""
- ...
-
-
-class DryAir(DerivedMetric):
- """Computes absolute value of the dry air tendency of the first time step,
- averaged over the batch. If the data does not contain the required fields,
- then returns NaN."""
-
- def __init__(
- self,
- area_weights: torch.Tensor,
- sigma_coordinates: SigmaCoordinates,
- device: torch.device,
- spatial_dims=(2, 3),
- ):
- self._area_weights = area_weights
- self._sigma_coordinates = sigma_coordinates
- self._dry_air_target_total: Optional[torch.Tensor] = None
- self._dry_air_gen_total: Optional[torch.Tensor] = None
- self._device = device
- self._spatial_dims: Tuple[int, int] = spatial_dims
-
- def record(self, target: ClimateData, gen: ClimateData) -> None:
- def _compute_dry_air_helper(climate_data: ClimateData) -> torch.Tensor:
- return compute_dry_air_absolute_differences(
- climate_data,
- area=self._area_weights,
- sigma_coordinates=self._sigma_coordinates,
- )[0]
-
- dry_air_target = _compute_dry_air_helper(target)
- dry_air_gen = _compute_dry_air_helper(gen)
-
- # initialize
- if self._dry_air_target_total is None:
- self._dry_air_target_total = torch.zeros_like(
- dry_air_target, device=self._device
- )
- if self._dry_air_gen_total is None:
- self._dry_air_gen_total = torch.zeros_like(dry_air_gen, device=self._device)
-
- self._dry_air_target_total += dry_air_target
- self._dry_air_gen_total += dry_air_gen
-
- def get(self) -> _TargetGenPair:
- if self._dry_air_target_total is None or self._dry_air_gen_total is None:
- raise ValueError("No batches have been recorded.")
- return _TargetGenPair(
- target=self._dry_air_target_total, gen=self._dry_air_gen_total
- )
-
-
-class DerivedMetricsAggregator:
- def __init__(
- self,
- area_weights: torch.Tensor,
- sigma_coordinates: SigmaCoordinates,
- climate_field_name_prefixes: Mapping[
- str, List[str]
- ] = CLIMATE_FIELD_NAME_PREFIXES,
- ):
- self.area_weights = area_weights
- self.sigma_coordinates = sigma_coordinates
- self.climate_field_name_prefixes = climate_field_name_prefixes
- device = get_device()
- self._derived_metrics: Dict[str, DerivedMetric] = {
- "surface_pressure_due_to_dry_air": DryAir(
- self.area_weights, self.sigma_coordinates, device=device
- )
- }
- self._n_batches = 0
-
- @torch.no_grad()
- def record_batch(
- self,
- loss: float,
- target_data: TensorMapping,
- gen_data: TensorMapping,
- target_data_norm: TensorMapping,
- gen_data_norm: TensorMapping,
- ):
- del loss, target_data_norm, gen_data_norm # unused
- target = ClimateData(target_data, self.climate_field_name_prefixes)
- gen = ClimateData(gen_data, self.climate_field_name_prefixes)
-
- for metric_fn in self._derived_metrics.values():
- metric_fn.record(target, gen)
-
- # only increment n_batches if we actually recorded a batch
- self._n_batches += 1
-
- def get_logs(self, label: str):
- logs = dict()
- for metric_name in self._derived_metrics:
- values = self._derived_metrics[metric_name].get()
- logs[f"{label}/{metric_name}/target"] = values.target / self._n_batches
- logs[f"{label}/{metric_name}/gen"] = values.gen / self._n_batches
- return logs
diff --git a/fme/fme/core/aggregator/one_step/test_main.py b/fme/fme/core/aggregator/one_step/test_main.py
deleted file mode 100644
index 628b2fc..0000000
--- a/fme/fme/core/aggregator/one_step/test_main.py
+++ /dev/null
@@ -1,183 +0,0 @@
-import pytest
-import torch
-
-from fme.core.aggregator.one_step import OneStepAggregator
-from fme.core.data_loading.data_typing import SigmaCoordinates
-from fme.core.device import get_device
-
-
-def test_labels_exist():
- n_sample = 10
- n_time = 3
- nx, ny, nz = 2, 2, 3
- loss = 1.0
- area_weights = torch.ones(ny).to(get_device())
- sigma_coordinates = SigmaCoordinates(torch.arange(nz + 1), torch.arange(nz + 1))
- agg = OneStepAggregator(area_weights, sigma_coordinates)
- target_data = {"a": torch.randn(n_sample, n_time, nx, ny, device=get_device())}
- gen_data = {"a": torch.randn(n_sample, n_time, nx, ny, device=get_device())}
- target_data_norm = {"a": torch.randn(n_sample, n_time, nx, ny, device=get_device())}
- gen_data_norm = {"a": torch.randn(n_sample, n_time, nx, ny, device=get_device())}
- agg.record_batch(loss, target_data, gen_data, target_data_norm, gen_data_norm)
- logs = agg.get_logs(label="test")
- assert "test/mean/loss" in logs
- assert "test/mean/weighted_rmse/a" in logs
- assert "test/mean/weighted_bias/a" in logs
- assert "test/mean/weighted_grad_mag_percent_diff/a" in logs
- assert "test/snapshot/image-full-field/a" in logs
- assert "test/snapshot/image-residual/a" in logs
- assert "test/snapshot/image-error/a" in logs
-
-
-def test_loss():
- """
- Basic test the aggregator combines loss correctly
- with multiple batches and no distributed training.
- """
- torch.manual_seed(0)
- example_data = {
- "a": torch.randn(1, 2, 5, 5, device=get_device()),
- }
- area_weights = torch.ones(1).to(get_device())
- nz = 3
- sigma_coordinates = SigmaCoordinates(torch.arange(nz + 1), torch.arange(nz + 1))
- aggregator = OneStepAggregator(area_weights, sigma_coordinates)
- aggregator.record_batch(
- loss=1.0,
- target_data=example_data,
- gen_data=example_data,
- target_data_norm=example_data,
- gen_data_norm=example_data,
- )
- aggregator.record_batch(
- loss=2.0,
- target_data=example_data,
- gen_data=example_data,
- target_data_norm=example_data,
- gen_data_norm=example_data,
- )
- logs = aggregator.get_logs(label="metrics")
- assert logs["metrics/mean/loss"] == 1.5
- aggregator.record_batch(
- loss=3.0,
- target_data=example_data,
- gen_data=example_data,
- target_data_norm=example_data,
- gen_data_norm=example_data,
- )
- logs = aggregator.get_logs(label="metrics")
- assert logs["metrics/mean/loss"] == 2.0
-
-
-def test_aggregator_raises_on_no_data():
- """
- Basic test the aggregator combines loss correctly
- with multiple batches and no distributed training.
- """
- ny, nz = 2, 3
- area_weights = torch.ones(ny).to(get_device())
- sigma_coordinates = SigmaCoordinates(torch.arange(nz + 1), torch.arange(nz + 1))
- agg = OneStepAggregator(area_weights, sigma_coordinates)
- with pytest.raises(ValueError) as excinfo:
- agg.record_batch(
- loss=1.0, target_data={}, gen_data={}, target_data_norm={}, gen_data_norm={}
- )
- # check that the raised exception contains the right substring
- assert "No data" in str(excinfo.value)
-
-
-def test_derived():
- n_sample = 5
- n_time = 3
- nx, ny, nz = 2, 4, 3
- loss = 1.0
- area_weights = torch.ones(ny).to(get_device())
- sigma_coordinates = SigmaCoordinates(
- torch.arange(nz + 1).to(get_device()), torch.arange(nz + 1).to(get_device())
- )
- agg = OneStepAggregator(area_weights, sigma_coordinates)
-
- def _make_data():
- fields = ["a", "PRESsfc"] + [f"specific_total_water_{i}" for i in range(nz)]
- return {
- field: torch.randn(n_sample, n_time, nx, ny, device=get_device())
- for field in fields
- }
-
- target_data = _make_data()
- gen_data = _make_data()
- target_data_norm = _make_data()
- gen_data_norm = _make_data()
-
- agg.record_batch(loss, target_data, gen_data, target_data_norm, gen_data_norm)
-
- logs = agg.get_logs("")
- target = logs["/derived/surface_pressure_due_to_dry_air/target"]
- gen = logs["/derived/surface_pressure_due_to_dry_air/gen"]
-
- assert target.shape == ()
- assert not torch.isnan(target).any()
- assert gen.shape == ()
- assert not torch.isnan(target).any()
-
-
-def test_derived_missing_surface_pressure():
- n_sample = 5
- n_time = 3
- nx, ny, nz = 2, 4, 3
- loss = 1.0
- area_weights = torch.ones(ny).to(get_device())
- sigma_coordinates = SigmaCoordinates(
- torch.arange(nz + 1).to(get_device()), torch.arange(nz + 1).to(get_device())
- )
- agg = OneStepAggregator(area_weights, sigma_coordinates)
-
- def _make_data():
- fields = ["a"] # N.B. no surface pressure or water fields.
- return {
- field: torch.randn(n_sample, n_time, nx, ny, device=get_device())
- for field in fields
- }
-
- target_data = _make_data()
- gen_data = _make_data()
- target_data_norm = _make_data()
- gen_data_norm = _make_data()
-
- agg.record_batch(loss, target_data, gen_data, target_data_norm, gen_data_norm)
-
- logs = agg.get_logs("")
- target = logs["/derived/surface_pressure_due_to_dry_air/target"]
- gen = logs["/derived/surface_pressure_due_to_dry_air/gen"]
-
- assert target.shape == () and torch.isnan(target).all()
- assert gen.shape == () and torch.isnan(target).all()
-
-
-def test__get_loss_scaled_mse_components():
- x = torch.ones(10).to(get_device())
- loss_scaling = {
- "a": torch.tensor(1.0),
- "b": torch.tensor(0.5),
- }
- agg = OneStepAggregator(
- area_weights=torch.ones(10).to(get_device()),
- sigma_coordinates=SigmaCoordinates(x, x),
- loss_scaling=loss_scaling,
- )
-
- logs = {
- "test/mean/weighted_rmse/a": 1.0,
- "test/mean/weighted_rmse/b": 4.0,
- "test/mean/weighted_rmse/c": 0.0,
- }
- result = agg._get_loss_scaled_mse_components(logs, "test")
- scaled_squared_errors_sum = (1.0 / 1.0) ** 2 + (4.0 / 0.5) ** 2
- assert (
- result["test/mean/mse_fractional_components/a"] == 1 / scaled_squared_errors_sum
- )
- assert (
- result["test/mean/mse_fractional_components/b"]
- == 64 / scaled_squared_errors_sum
- )
- assert "test/mean/mse_fractional_components/c" not in result
diff --git a/fme/fme/core/aggregator/test_plotting.py b/fme/fme/core/aggregator/test_plotting.py
deleted file mode 100644
index 612b3cd..0000000
--- a/fme/fme/core/aggregator/test_plotting.py
+++ /dev/null
@@ -1,48 +0,0 @@
-import numpy as np
-import pytest
-
-from .plotting import _stitch_data_panels, get_cmap_limits, plot_imshow
-
-
-def test_cmap_limits():
- data = np.array([1, 2, 3])
- vmin, vmax = get_cmap_limits(data)
- assert vmin == 1
- assert vmax == 3
-
-
-def test_cmap_limits_diverging():
- data = np.array([-1, 2, 3])
- vmin, vmax = get_cmap_limits(data, diverging=True)
- assert vmin == -3
- assert vmax == 3
-
-
-@pytest.mark.parametrize("use_colorbar", [True, False])
-def test_plot_imshow(use_colorbar):
- shape = [10, 15]
- data = np.random.randn(*shape)
- fig = plot_imshow(np.array(data), use_colorbar=use_colorbar)
- width, height = (fig.get_size_inches() * fig.dpi).astype(int)
- if use_colorbar:
- # colorbar is no more than 15% of the width but greater than 0 pixels
- assert shape[1] < width <= int(shape[1] * 1.15)
- assert height == shape[0]
- else:
- assert [height, width] == shape
-
-
-def test_stitch_data_panels():
- data = [
- [np.array([[1, 2]]), np.array([[3, 4]])],
- [np.array([[5, 6]]), np.array([[7, 8]])],
- ]
- stitched = _stitch_data_panels(data, vmin=1)
- expected = np.array(
- [ # vertical orientation is swapped as data starts from bottom-left
- [5, 6, 1, 7, 8],
- [1, 1, 1, 1, 1],
- [1, 2, 1, 3, 4],
- ]
- )
- assert np.array_equal(stitched, expected)
diff --git a/fme/fme/core/climate_data.py b/fme/fme/core/climate_data.py
index ebb5077..9891f05 100644
--- a/fme/fme/core/climate_data.py
+++ b/fme/fme/core/climate_data.py
@@ -1,6 +1,5 @@
-import re
from types import MappingProxyType
-from typing import List, Mapping, Union
+from typing import Callable, List, Mapping
import torch
@@ -12,7 +11,8 @@
RVGAS,
SPECIFIC_HEAT_OF_DRY_AIR_CONST_PRESSURE,
)
-from fme.core.data_loading.data_typing import SigmaCoordinates
+from fme.core.coordinates import HybridSigmaPressureCoordinate
+from fme.core.stacker import Stacker
from fme.core.typing_ import TensorDict, TensorMapping
CLIMATE_FIELD_NAME_PREFIXES = MappingProxyType(
@@ -20,47 +20,29 @@
"specific_total_water": ["specific_total_water_"],
"surface_pressure": ["PRESsfc", "PS"],
"surface_height": ["HGTsfc"],
+ "surface_geopotential": ["PHIS"],
"tendency_of_total_water_path_due_to_advection": [
"tendency_of_total_water_path_due_to_advection"
],
"latent_heat_flux": ["LHTFLsfc", "LHFLX"],
- "sensible_heat_flux": ["SHTFLsfc"],
+ "sensible_heat_flux": ["SHTFLsfc", "SHFLX"],
"precipitation_rate": ["PRATEsfc", "surface_precipitation_rate"],
- "sfc_down_sw_radiative_flux": ["DSWRFsfc"],
- "sfc_up_sw_radiative_flux": ["USWRFsfc"],
- "sfc_down_lw_radiative_flux": ["DLWRFsfc"],
- "sfc_up_lw_radiative_flux": ["ULWRFsfc"],
- "air_temperature": ["air_temperature_"],
+ "sfc_down_sw_radiative_flux": ["DSWRFsfc", "FSDS"],
+ "sfc_up_sw_radiative_flux": ["USWRFsfc", "surface_upward_shortwave_flux"],
+ "sfc_down_lw_radiative_flux": ["DLWRFsfc", "FLDS"],
+ "sfc_up_lw_radiative_flux": ["ULWRFsfc", "surface_upward_longwave_flux"],
+ "toa_up_lw_radiative_flux": ["ULWRFtoa", "FLUT"],
+ "toa_up_sw_radiative_flux": ["USWRFtoa", "top_of_atmos_upward_shortwave_flux"],
+ "toa_down_sw_radiative_flux": ["DSWRFtoa", "SOLIN"],
+ "air_temperature": ["air_temperature_", "T_"],
}
)
-def natural_sort(alist: List[str]) -> List[str]:
- """Sort to alphabetical order but with numbers sorted
- numerically, e.g. a11 comes after a2. See [1] and [2].
-
- [1] https://stackoverflow.com/questions/11150239/natural-sorting
- [2] https://en.wikipedia.org/wiki/Natural_sort_order
- """
-
- def convert(text: str) -> Union[str, int]:
- if text.isdigit():
- return int(text)
- else:
- return text.lower()
-
- def alphanum_key(item: str) -> List[Union[str, int]]:
- return [convert(c) for c in re.split("([0-9]+)", item)]
-
- return sorted(alist, key=alphanum_key)
-
-
-LEVEL_PATTERN = re.compile(r"_(\d+)$")
-
-
class ClimateData:
"""Container for climate data for accessing variables and providing
- torch.Tensor views on data with multiple vertical levels."""
+ torch.Tensor views on data with multiple vertical levels.
+ """
def __init__(
self,
@@ -74,87 +56,66 @@ def __init__(
Args:
climate_data: Mapping from field names to tensors.
- climate_field_name_prefixes: Mapping from field name prefixes (e.g.
- "specific_total_water_") to standardized prefixes, e.g. "PRESsfc" →
- "surface_pressure".
+ climate_field_name_prefixes: Mapping which defines the correspondence
+ between an arbitrary set of "standard" names (e.g., "surface_pressure"
+ or "air_temperature") and lists of possible names or prefix variants
+ (e.g., ["PRESsfc", "PS"] or ["air_temperature_", "T_"]) found in the
+ data.
"""
self._data = dict(climate_data)
- self._prefixes = climate_field_name_prefixes
-
- def _extract_levels(self, name: List[str]) -> torch.Tensor:
- for prefix in name:
- try:
- return self._extract_prefix_levels(prefix)
- except KeyError:
- pass
- raise KeyError(name)
-
- def _extract_prefix_levels(self, prefix: str) -> torch.Tensor:
- names = [
- field_name for field_name in self._data if field_name.startswith(prefix)
- ]
-
- levels = []
- for name in names:
- match = LEVEL_PATTERN.search(name)
- if match is None:
- raise ValueError(
- f"Invalid field name {name}, is a prefix variable "
- "but does not end in _(number)."
- )
- levels.append(int(match.group(1)))
+ self._prefix_map = climate_field_name_prefixes
+ self._stacker = Stacker(climate_field_name_prefixes)
- for i, level in enumerate(sorted(levels)):
- if i != level:
- raise KeyError(f"Missing level {i} in {prefix} levels {levels}.")
-
- if len(names) == 0:
- raise KeyError(prefix)
+ @property
+ def data(self) -> TensorDict:
+ """Mapping from field names to tensors."""
+ return self._data
- names = natural_sort(names)
- return torch.stack([self._data[name] for name in names], dim=-1)
-
- def _get(self, name):
- for prefix in self._prefixes[name]:
- if prefix in self._data.keys():
- return self._get_prefix(prefix)
- raise KeyError(name)
+ def __getitem__(self, name: str):
+ return getattr(self, name)
def _get_prefix(self, prefix):
- return self._data[prefix]
+ return self.data[prefix]
def _set(self, name, value):
- for prefix in self._prefixes[name]:
- if prefix in self._data.keys():
+ for prefix in self._prefix_map[name]:
+ if prefix in self.data.keys():
self._set_prefix(prefix, value)
return
raise KeyError(name)
def _set_prefix(self, prefix, value):
- self._data[prefix] = value
+ self.data[prefix] = value
- @property
- def data(self) -> TensorDict:
- """Mapping from field names to tensors."""
- return self._data
+ def _get(self, name):
+ for prefix in self._prefix_map[name]:
+ if prefix in self.data.keys():
+ return self._get_prefix(prefix)
+ raise KeyError(name)
@property
def air_temperature(self) -> torch.Tensor:
"""Returns all vertical levels of air_temperature, e.g. a tensor of
- shape `(..., vertical_level)`."""
- prefix = self._prefixes["air_temperature"]
- return self._extract_levels(prefix)
+ shape `(..., vertical_level)`.
+ """
+ return self._stacker("air_temperature", self.data)
@property
def specific_total_water(self) -> torch.Tensor:
"""Returns all vertical levels of specific total water, e.g. a tensor of
- shape `(..., vertical_level)`."""
- prefix = self._prefixes["specific_total_water"]
- return self._extract_levels(prefix)
+ shape `(..., vertical_level)`.
+ """
+ return self._stacker("specific_total_water", self.data)
@property
def surface_height(self) -> torch.Tensor:
- return self._get("surface_height")
+ try:
+ return self._get("surface_height")
+ except KeyError:
+            # E3SM saves surface geopotential rather than surface height, so
+            # convert using the gravity constant that E3SM uses.
+ GRAVITY_E3SM = 9.80616
+ return self._get("surface_geopotential") / GRAVITY_E3SM
@property
def surface_pressure(self) -> torch.Tensor:
@@ -164,22 +125,45 @@ def surface_pressure(self) -> torch.Tensor:
def surface_pressure(self, value: torch.Tensor):
self._set("surface_pressure", value)
+ @property
+ def toa_down_sw_radiative_flux(self) -> torch.Tensor:
+ return self._get("toa_down_sw_radiative_flux")
+
+ @toa_down_sw_radiative_flux.setter
+ def toa_down_sw_radiative_flux(self, value: torch.Tensor):
+ self._set("toa_down_sw_radiative_flux", value)
+
+ @property
+ def toa_up_sw_radiative_flux(self) -> torch.Tensor:
+ return self._get("toa_up_sw_radiative_flux")
+
+ @toa_up_sw_radiative_flux.setter
+ def toa_up_sw_radiative_flux(self, value: torch.Tensor):
+ self._set("toa_up_sw_radiative_flux", value)
+
+ @property
+ def toa_up_lw_radiative_flux(self) -> torch.Tensor:
+ return self._get("toa_up_lw_radiative_flux")
+
+ @toa_up_lw_radiative_flux.setter
+ def toa_up_lw_radiative_flux(self, value: torch.Tensor):
+ self._set("toa_up_lw_radiative_flux", value)
+
def surface_pressure_due_to_dry_air(
- self, sigma_coordinates: SigmaCoordinates
+ self, vertical_coordinate: HybridSigmaPressureCoordinate
) -> torch.Tensor:
return metrics.surface_pressure_due_to_dry_air(
self.specific_total_water,
self.surface_pressure,
- sigma_coordinates.ak,
- sigma_coordinates.bk,
+ vertical_coordinate,
)
- def total_water_path(self, sigma_coordinates: SigmaCoordinates) -> torch.Tensor:
- return metrics.vertical_integral(
+ def total_water_path(
+ self, vertical_coordinate: HybridSigmaPressureCoordinate
+ ) -> torch.Tensor:
+ return vertical_coordinate.vertical_integral(
self.specific_total_water,
self.surface_pressure,
- sigma_coordinates.ak,
- sigma_coordinates.bk,
)
@property
@@ -240,23 +224,25 @@ def tendency_of_total_water_path_due_to_advection(self, value: torch.Tensor):
self._set("tendency_of_total_water_path_due_to_advection", value)
def height_at_log_midpoint(
- self, sigma_coordinates: SigmaCoordinates
+ self, vertical_coordinate: HybridSigmaPressureCoordinate
) -> torch.Tensor:
"""
- Compute vertical height at layer log midpoints
+ Compute vertical height at layer log midpoints.
"""
- pressure_interfaces = _pressure_at_interface(
- sigma_coordinates.ak, sigma_coordinates.bk, self.surface_pressure
+ interface_pressure = vertical_coordinate.interface_pressure(
+ self.surface_pressure
)
layer_thickness = _layer_thickness(
- pressure_at_interface=pressure_interfaces,
+ pressure_at_interface=interface_pressure,
air_temperature=self.air_temperature,
specific_total_water=self.specific_total_water,
)
height_at_interface = _height_at_interface(layer_thickness, self.surface_height)
return (height_at_interface[..., :-1] * height_at_interface[..., 1:]) ** 0.5
- def moist_static_energy(self, sigma_coordinates: SigmaCoordinates) -> torch.Tensor:
+ def moist_static_energy(
+ self, vertical_coordinate: HybridSigmaPressureCoordinate
+ ) -> torch.Tensor:
"""
Compute moist static energy.
"""
@@ -265,20 +251,22 @@ def moist_static_energy(self, sigma_coordinates: SigmaCoordinates) -> torch.Tens
return (
self.air_temperature * SPECIFIC_HEAT_OF_DRY_AIR_CONST_PRESSURE
+ self.specific_total_water * LATENT_HEAT_OF_VAPORIZATION
- + self.height_at_log_midpoint(sigma_coordinates) * GRAVITY
+ + self.height_at_log_midpoint(vertical_coordinate) * GRAVITY
)
def compute_dry_air_absolute_differences(
- climate_data: ClimateData, area: torch.Tensor, sigma_coordinates: SigmaCoordinates
+ climate_data: ClimateData,
+ area_weighted_mean: Callable[[torch.Tensor], torch.Tensor],
+ vertical_coordinate: HybridSigmaPressureCoordinate,
) -> torch.Tensor:
"""
Computes the absolute value of the dry air tendency of each time step.
Args:
climate_data: ClimateData object.
- area: Area of each grid cell as a [lat, lon] tensor, in m^2.
- sigma_coordinates: The sigma coordinates of the model.
+ area_weighted_mean: Function which returns an area-weighted mean.
+ vertical_coordinate: The vertical coordinate of the model.
Returns:
A tensor of shape (time,) of the absolute value of the dry air tendency
@@ -289,21 +277,11 @@ def compute_dry_air_absolute_differences(
pressure = climate_data.surface_pressure
except KeyError:
return torch.tensor([torch.nan])
- return (
- metrics.weighted_mean(
- metrics.surface_pressure_due_to_dry_air(
- water, # (sample, time, y, x, level)
- pressure,
- sigma_coordinates.ak,
- sigma_coordinates.bk,
- ),
- area,
- dim=(2, 3),
- )
- .diff(dim=-1)
- .abs()
- .mean(dim=0)
+ ps_dry = metrics.surface_pressure_due_to_dry_air(
+ water, pressure, vertical_coordinate
)
+ ps_dry_mean = area_weighted_mean(ps_dry)
+ return ps_dry_mean.diff(dim=-1).abs().mean(dim=0)
def _layer_thickness(
@@ -314,25 +292,14 @@ def _layer_thickness(
"""
Computes vertical thickness of each layer assuming hydrostatic equilibrium.
ACE does not currently prognose specific humidity, so here we closely
- approximate this using specific total water."""
+ approximate this using specific total water.
+ """
tv = air_temperature * (1 + (RVGAS / RDGAS - 1.0) * specific_total_water)
- dlogp = torch.log(pressure_at_interface).diff(dim=-1)
+ # Enforce min log(p) = 0 so that geopotential energy calculation is finite
+ dlogp = torch.clamp(torch.log(pressure_at_interface), min=0.0).diff(dim=-1)
return dlogp * RDGAS * tv / GRAVITY
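For reference, the expression above is the hypsometric relation integrated over one layer (a sketch of the reasoning, not new code):

    dz_k = (RDGAS * Tv_k / GRAVITY) * (ln p_(k+1/2) - ln p_(k-1/2))

with the virtual temperature Tv approximated from specific total water, as in the line defining `tv` above.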
-def _pressure_at_interface(
- ak: torch.tensor, bk: torch.tensor, surface_pressure: torch.tensor
-) -> torch.Tensor:
- """
- Computes pressure at layer interfaces from sigma coefficients.
- Vertical coordinate is the last tensor dimension.
- """
- return torch.stack(
- [ak[i] + bk[i] * surface_pressure for i in range(ak.shape[-1])],
- dim=-1,
- )
-
-
def _height_at_interface(
layer_thickness: torch.tensor, surface_height: torch.tensor
) -> torch.Tensor:
diff --git a/fme/fme/core/coordinates.py b/fme/fme/core/coordinates.py
new file mode 100644
index 0000000..8e059a9
--- /dev/null
+++ b/fme/fme/core/coordinates.py
@@ -0,0 +1,397 @@
+import abc
+import dataclasses
+from typing import List, Literal, Mapping, Optional, Tuple, TypeVar
+
+import numpy as np
+import torch
+from astropy_healpix import HEALPix
+
+from fme.core import metrics
+from fme.core.constants import GRAVITY
+from fme.core.gridded_ops import GriddedOperations, HEALPixOperations, LatLonOperations
+from fme.core.typing_ import TensorMapping
+from fme.core.winds import lon_lat_to_xyz
+
+HC = TypeVar("HC", bound="HorizontalCoordinates")
+
+
+@dataclasses.dataclass
+class HybridSigmaPressureCoordinate:
+ """
+ Defines pressure at interface levels according to the following formula:
+ p(k) = a(k) + b(k)*ps.
+
+    where ps is the surface pressure and a, b are the coordinate coefficients.
+
+ Parameters:
+ ak: a(k) coefficients as a 1-dimensional tensor
+ bk: b(k) coefficients as a 1-dimensional tensor
+ """
+
+ ak: torch.Tensor
+ bk: torch.Tensor
+
+ def __post_init__(self):
+ if len(self.ak.shape) != 1:
+ raise ValueError(
+ f"ak must be a 1-dimensional tensor. Got shape: {self.ak.shape}"
+ )
+ if len(self.bk.shape) != 1:
+ raise ValueError(
+ f"bk must be a 1-dimensional tensor. Got shape: {self.bk.shape}"
+ )
+ if len(self.ak) != len(self.bk):
+ raise ValueError(
+ f"ak and bk must have the same length. Got len(ak)={len(self.ak)} and "
+ f"len(bk)={len(self.bk)}."
+ )
+
+ def __len__(self):
+ """The number of vertical layer interfaces."""
+ return len(self.ak)
+
+ @property
+ def coords(self) -> Mapping[str, np.ndarray]:
+ return {"ak": self.ak.cpu().numpy(), "bk": self.bk.cpu().numpy()}
+
+ def to(self, device: str) -> "HybridSigmaPressureCoordinate":
+ return HybridSigmaPressureCoordinate(
+ ak=self.ak.to(device),
+ bk=self.bk.to(device),
+ )
+
+ def __eq__(self, other) -> bool:
+ if not isinstance(other, HybridSigmaPressureCoordinate):
+ return False
+ return torch.allclose(self.ak, other.ak) and torch.allclose(self.bk, other.bk)
+
+ def as_dict(self) -> TensorMapping:
+ return {"ak": self.ak, "bk": self.bk}
+
+ def interface_pressure(self, surface_pressure: torch.Tensor) -> torch.Tensor:
+ """
+ Compute pressure at vertical layer interfaces.
+
+ Args:
+ surface_pressure: The surface pressure in units of Pa.
+
+ Returns:
+ A tensor of pressure at vertical layer interfaces. Will contain a new
+ dimension at the end, representing the vertical.
+ """
+ return torch.stack(
+ [ak + bk * surface_pressure for ak, bk in zip(self.ak, self.bk)],
+ dim=-1,
+ )
+
+ def vertical_integral(
+ self, integrand: torch.Tensor, surface_pressure: torch.Tensor
+ ) -> torch.Tensor:
+ """
+ Compute the mass-weighted vertical integral of the integrand.
+
+ (1 / g) * ∫ x dp
+
+ where
+ - g = acceleration due to gravity
+        - x = integrand
+ - p = pressure level
+
+ Args:
+            integrand: A tensor whose last dimension is the vertical.
+            surface_pressure: The surface pressure in units of Pa.
+
+ Returns:
+ A tensor of same shape as integrand but without the last dimension.
+ """
+ if len(self.ak) != integrand.shape[-1] + 1:
+ raise ValueError(
+ "The last dimension of integrand must match the number of vertical "
+ "layers in the hybrid sigma-pressure vertical coordinate."
+ )
+ interface_pressure = self.interface_pressure(surface_pressure)
+ pressure_thickness = interface_pressure.diff(dim=-1)
+ return (integrand * pressure_thickness).sum(dim=-1) / GRAVITY
+
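A minimal usage sketch of the new coordinate class (the ak/bk values mirror the unit tests later in this diff; the shapes are illustrative, not prescribed):

    coords = HybridSigmaPressureCoordinate(
        ak=torch.tensor([3.0, 1.0, 0.0]), bk=torch.tensor([0.0, 0.6, 1.0])
    )
    ps = 1.0e5 * torch.ones(2, 4, 8, 16)          # surface pressure in Pa
    p_interface = coords.interface_pressure(ps)   # shape (2, 4, 8, 16, 3)
    q = torch.rand(2, 4, 8, 16, 2)                # two-layer integrand
    column = coords.vertical_integral(q, ps)      # (1/g) * sum_k q_k * dp_k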
+
+@dataclasses.dataclass
+class DimSize:
+ name: str
+ size: int
+
+
+class HorizontalCoordinates(abc.ABC):
+ """
+ Parent class for horizontal coordinate system grids.
+    Subclasses must provide the coordinates via the abstract properties below.
+ """
+
+ @abc.abstractmethod
+ def __eq__(self, other) -> bool:
+ pass
+
+ @abc.abstractmethod
+ def to(self: HC, device: str) -> HC:
+ pass
+
+ @property
+ @abc.abstractmethod
+ def coords(self) -> Mapping[str, np.ndarray]:
+ pass
+
+ @property
+ @abc.abstractmethod
+ def xyz(self) -> Tuple[float, float, float]:
+ pass
+
+ @property
+ @abc.abstractmethod
+ def dims(self) -> List[str]:
+ """Names of model horizontal dimensions."""
+ pass
+
+ @property
+ @abc.abstractmethod
+ def loaded_dims(self) -> List[str]:
+ """Names of horizontal dimensions as loaded from training dataset."""
+ pass
+
+ @property
+ @abc.abstractmethod
+ def loaded_sizes(self) -> List[DimSize]:
+ """Sizes of horizontal dimensions as loaded from training dataset."""
+ pass
+
+ @property
+ @abc.abstractmethod
+ def loaded_default_sizes(self) -> List[DimSize]:
+ """Default sizes of horizontal data dimensions, used by testing code."""
+ pass
+
+ @property
+ @abc.abstractmethod
+ def grid(self) -> Literal["equiangular", "legendre-gauss", "healpix"]:
+ pass
+
+ # A temporary solution for training which allows us to aggregate along the
+ # latitude dimension.
+ # TODO: https://github.com/ai2cm/full-model/issues/1003
+ @abc.abstractmethod
+ def get_lat(self) -> torch.Tensor:
+ pass
+
+ @property
+ @abc.abstractmethod
+ def area_weights(self) -> Optional[torch.Tensor]:
+ pass
+
+ @property
+ @abc.abstractmethod
+ def gridded_operations(self) -> GriddedOperations:
+ pass
+
+ @property
+ @abc.abstractmethod
+ def meshgrid(self) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Meshgrids of latitudes and longitudes, respectively."""
+ pass
+
+
+@dataclasses.dataclass
+class LatLonCoordinates(HorizontalCoordinates):
+ """
+ Defines a (latitude, longitude) grid.
+
+ Parameters:
+ lat: 1-dimensional tensor of latitudes
+ lon: 1-dimensional tensor of longitudes
+ loaded_lat_name: name of the latitude dimension
+ as loaded from training dataset
+ loaded_lon_name: name of the longitude dimension
+ as loaded from training dataset
+ """
+
+ lon: torch.Tensor
+ lat: torch.Tensor
+ loaded_lat_name: str = "lat"
+ loaded_lon_name: str = "lon"
+
+ def __post_init__(self):
+ self._area_weights: Optional[torch.Tensor] = None
+
+ def __eq__(self, other) -> bool:
+ if not isinstance(other, LatLonCoordinates):
+ return False
+ return (
+ torch.allclose(self.lat, other.lat)
+ and torch.allclose(self.lon, other.lon)
+ and self.loaded_lat_name == other.loaded_lat_name
+ and self.loaded_lon_name == other.loaded_lon_name
+ )
+
+ def to(self, device: str) -> "LatLonCoordinates":
+ return LatLonCoordinates(
+ lon=self.lon.to(device),
+ lat=self.lat.to(device),
+ loaded_lat_name=self.loaded_lat_name,
+ loaded_lon_name=self.loaded_lon_name,
+ )
+
+ @property
+ def area_weights(self) -> torch.Tensor:
+ if self._area_weights is None:
+ self._area_weights = metrics.spherical_area_weights(self.lat, len(self.lon))
+ return self._area_weights
+
+ @property
+ def coords(self) -> Mapping[str, np.ndarray]:
+ # TODO: Replace with lat/lon name?
+ return {
+ "lat": self.lat.cpu().type(torch.float32).numpy(),
+ "lon": self.lon.cpu().type(torch.float32).numpy(),
+ }
+
+ @property
+ def xyz(self) -> Tuple[float, float, float]:
+ lats, lons = np.broadcast_arrays(
+ self.coords["lat"][:, None], self.coords["lon"][None, :]
+ )
+ return lon_lat_to_xyz(lons, lats)
+
+ def get_lat(self) -> torch.Tensor:
+ return self.lat
+
+ @property
+ def dims(self) -> List[str]:
+ return ["lat", "lon"]
+
+ @property
+ def loaded_dims(self) -> List[str]:
+ return [self.loaded_lat_name, self.loaded_lon_name]
+
+ @property
+ def loaded_sizes(self) -> List[DimSize]:
+ return [
+ DimSize(self.loaded_lat_name, len(self.lat)),
+ DimSize(self.loaded_lon_name, len(self.lon)),
+ ]
+
+ @property
+ def loaded_default_sizes(self) -> List[DimSize]:
+ return [DimSize(self.loaded_lat_name, 16), DimSize(self.loaded_lon_name, 32)]
+
+ @property
+ def grid(self) -> Literal["equiangular", "legendre-gauss"]:
+ if torch.allclose(
+ self.lat[1:] - self.lat[:-1],
+ self.lat[1] - self.lat[0],
+ ):
+ return "equiangular"
+ else:
+ return "legendre-gauss"
+
+ @property
+ def gridded_operations(self) -> LatLonOperations:
+ return LatLonOperations(self.area_weights)
+
+ @property
+ def meshgrid(self) -> Tuple[torch.Tensor, torch.Tensor]:
+ return torch.meshgrid(self.lat, self.lon, indexing="ij")
+
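A brief sketch of the grid-type inference: uniformly spaced latitudes are reported as "equiangular", anything else as "legendre-gauss" (the latitude and longitude values below are arbitrary):

    latlon = LatLonCoordinates(
        lat=torch.linspace(-89.5, 89.5, 180), lon=torch.arange(0.0, 360.0, 1.0)
    )
    latlon.grid          # "equiangular", since the latitude spacing is uniform
    latlon.area_weights  # computed lazily via metrics.spherical_area_weights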
+
+@dataclasses.dataclass
+class HEALPixCoordinates(HorizontalCoordinates):
+ """
+ Defines a HEALPix (face, height, width) grid. See https://healpix.jpl.nasa.gov/ for
+ more information.
+
+ Parameters:
+ face: 1-dimensional tensor of faces
+ height: 1-dimensional tensor of heights
+ width: 1-dimensional tensor of widths
+ """
+
+ face: torch.Tensor
+ height: torch.Tensor
+ width: torch.Tensor
+
+ def __eq__(self, other) -> bool:
+ if not isinstance(other, HEALPixCoordinates):
+ return False
+ return (
+ torch.allclose(self.face, other.face)
+ and torch.allclose(self.height, other.height)
+ and torch.allclose(self.width, other.width)
+ )
+
+ def to(self, device: str) -> "HEALPixCoordinates":
+ return HEALPixCoordinates(
+ face=self.face.to(device),
+ height=self.height.to(device),
+ width=self.width.to(device),
+ )
+
+ @property
+ def coords(self) -> Mapping[str, np.ndarray]:
+ return {
+ "face": self.face.cpu().type(torch.float32).numpy(),
+ "height": self.height.cpu().type(torch.float32).numpy(),
+ "width": self.width.cpu().type(torch.float32).numpy(),
+ }
+
+ @property
+ def xyz(self) -> Tuple[float, float, float]:
+ hp = HEALPix(nside=len(self.height), order="ring")
+ return hp.healpix_to_xyz(
+ [self.coords["face"], self.coords["height"], self.coords["width"]]
+ )
+
+ @property
+ def dims(self) -> List[str]:
+ return ["face", "height", "width"]
+
+ @property
+ def loaded_dims(self) -> List[str]:
+ return self.dims
+
+ @property
+ def loaded_sizes(self) -> List[DimSize]:
+ return [
+ DimSize("face", len(self.face)),
+            DimSize("height", len(self.height)),
+            DimSize("width", len(self.width)),
+ ]
+
+ @property
+    def loaded_default_sizes(self) -> List[DimSize]:
+ return [
+ DimSize("face", 12),
+ DimSize("height", 16),
+ DimSize("width", 16),
+ ]
+
+ # TODO: https://github.com/ai2cm/full-model/issues/1003
+ # This is currently the dummy solution.
+ def get_lat(self) -> torch.Tensor:
+        raise NotImplementedError(
+            "healpix does not support get_lat. If latitude is needed for some "
+            "reason, you may use this class's self.xyz property to derive it."
+ )
+
+ @property
+ def grid(self) -> Literal["healpix"]:
+ return "healpix"
+
+ @property
+ def area_weights(self) -> Literal[None]:
+ return None
+
+ @property
+ def gridded_operations(self) -> HEALPixOperations:
+ return HEALPixOperations()
+
+ @property
+ def meshgrid(self) -> Tuple[torch.Tensor, torch.Tensor]:
+ raise NotImplementedError(
+ "meshgrid is not implemented yet for HEALPixCoordinates."
+ )
diff --git a/fme/fme/core/corrector/__init__.py b/fme/fme/core/corrector/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/fme/fme/core/corrector.py b/fme/fme/core/corrector/corrector.py
similarity index 67%
rename from fme/fme/core/corrector.py
rename to fme/fme/core/corrector/corrector.py
index 1c01ab1..bf3f72f 100644
--- a/fme/fme/core/corrector.py
+++ b/fme/fme/core/corrector/corrector.py
@@ -1,17 +1,22 @@
import dataclasses
import datetime
-from typing import List, Literal, Optional
+from typing import Any, Callable, List, Literal, Mapping, Optional, Protocol
+import dacite
import torch
-from fme.core import metrics
+import fme
from fme.core.climate_data import ClimateData
-from fme.core.data_loading.data_typing import SigmaCoordinates
+from fme.core.coordinates import HybridSigmaPressureCoordinate
+from fme.core.corrector.registry import CorrectorABC, CorrectorConfigProtocol
+from fme.core.gridded_ops import GriddedOperations
+from fme.core.registry.corrector import CorrectorSelector
from fme.core.typing_ import TensorDict, TensorMapping
+@CorrectorSelector.register("atmosphere_corrector")
@dataclasses.dataclass
-class CorrectorConfig:
+class CorrectorConfig(CorrectorConfigProtocol):
r"""
Configuration for the post-step state corrector.
@@ -57,7 +62,7 @@ class CorrectorConfig:
advection is zero. Therefore ``zero_global_mean_moisture_advection`` must be
True if using a ``moisture_budget_correction`` option other than ``None``.
- Attributes:
+ Parameters:
conserve_dry_air: If True, force the generated data to conserve dry air
by subtracting a constant offset from the surface pressure of each
column. This can cause changes in per-mass values such as total water
@@ -67,19 +72,21 @@ class CorrectorConfig:
offset from the moisture advection tendency of each column.
moisture_budget_correction: If not "None", force the generated data to
conserve global or column-local moisture by modifying budget fields.
- Options include:
- - "precipitation": multiply precipitation by a scale factor
- to close the global moisture budget.
- - "evaporation": multiply evaporation by a scale factor
- to close the global moisture budget.
- - "advection_and_precipitation": after applying the "precipitation"
- global-mean correction above, recompute the column-integrated
- advective tendency as the budget residual,
- ensuring column budget closure.
- - "advection_and_evaporation": after applying the "evaporation"
- global-mean correction above, recompute the column-integrated
- advective tendency as the budget residual,
- ensuring column budget closure.
+ Options are:
+
+ - ``precipitation``: multiply precipitation by a scale factor
+ to close the global moisture budget.
+ - ``evaporation``: multiply evaporation by a scale factor
+ to close the global moisture budget.
+ - ``advection_and_precipitation``: after applying the "precipitation"
+ global-mean correction above, recompute the column-integrated
+ advective tendency as the budget residual,
+ ensuring column budget closure.
+ - ``advection_and_evaporation``: after applying the "evaporation"
+ global-mean correction above, recompute the column-integrated
+ advective tendency as the budget residual,
+ ensuring column budget closure.
+
force_positive_names: Names of fields that should be forced to be greater
than or equal to zero. This is useful for fields like precipitation.
"""
@@ -98,65 +105,87 @@ class CorrectorConfig:
def build(
self,
- area: torch.Tensor,
- sigma_coordinates: SigmaCoordinates,
+ gridded_operations: GriddedOperations,
+ vertical_coordinate: HybridSigmaPressureCoordinate,
timestep: datetime.timedelta,
- ) -> Optional["Corrector"]:
+ ) -> "Corrector":
return Corrector(
config=self,
- area=area,
- sigma_coordinates=sigma_coordinates,
+ gridded_operations=gridded_operations,
+ vertical_coordinate=vertical_coordinate,
timestep=timestep,
)
+ @classmethod
+ def from_state(cls, state: Mapping[str, Any]) -> "CorrectorConfig":
+ return dacite.from_dict(
+ data_class=cls, data=state, config=dacite.Config(strict=True)
+ )
+
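As a sketch of how this configuration might be used (the field values and the "PRATEsfc" name are illustrative; gridded_operations, vertical_coordinate, timestep and the three data mappings are assumed to be in scope):

    config = CorrectorConfig(
        conserve_dry_air=True,
        zero_global_mean_moisture_advection=True,
        moisture_budget_correction="advection_and_precipitation",
        force_positive_names=["PRATEsfc"],
    )
    corrector = config.build(gridded_operations, vertical_coordinate, timestep)
    corrected = corrector(input_data, gen_data, forcing_data)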
-class Corrector:
+class Corrector(CorrectorABC):
def __init__(
self,
config: CorrectorConfig,
- area: torch.Tensor,
- sigma_coordinates: SigmaCoordinates,
+ gridded_operations: GriddedOperations,
+ vertical_coordinate: HybridSigmaPressureCoordinate,
timestep: datetime.timedelta,
):
self._config = config
- self._area = area
- self._sigma_coordinates = sigma_coordinates
+ self._gridded_operations = gridded_operations
+ self._vertical_coordinates = vertical_coordinate
self._timestep = timestep
+ if fme.get_device() == torch.device("mps", 0):
+ self._dry_air_precision = torch.float32
+ else:
+ self._dry_air_precision = torch.float64
def __call__(
self,
input_data: TensorMapping,
gen_data: TensorMapping,
- ):
+ forcing_data: TensorMapping,
+ ) -> TensorMapping:
+ """Apply corrections to the generated data.
+
+ Args:
+ input_data: The input time step data.
+ gen_data: The data generated by the model, to be corrected.
+ forcing_data: The forcing data for the same time step as gen_data.
+
+ Returns:
+ The corrected data.
+ """
if len(self._config.force_positive_names) > 0:
# do this step before imposing other conservation correctors, since
# otherwise it could end up creating violations of those constraints.
- gen_data = _force_positive(gen_data, self._config.force_positive_names)
+ gen_data = force_positive(gen_data, self._config.force_positive_names)
if self._config.conserve_dry_air:
gen_data = _force_conserve_dry_air(
input_data=input_data,
gen_data=gen_data,
- area=self._area,
- sigma_coordinates=self._sigma_coordinates,
+ area_weighted_mean=self._gridded_operations.area_weighted_mean,
+ vertical_coordinate=self._vertical_coordinates,
+ precision=self._dry_air_precision,
)
if self._config.zero_global_mean_moisture_advection:
gen_data = _force_zero_global_mean_moisture_advection(
gen_data=gen_data,
- area=self._area,
+ area_weighted_mean=self._gridded_operations.area_weighted_mean,
)
if self._config.moisture_budget_correction is not None:
gen_data = _force_conserve_moisture(
input_data=input_data,
gen_data=gen_data,
- area=self._area,
- sigma_coordinates=self._sigma_coordinates,
+ area_weighted_mean=self._gridded_operations.area_weighted_mean,
+ vertical_coordinate=self._vertical_coordinates,
timestep=self._timestep,
terms_to_modify=self._config.moisture_budget_correction,
)
return gen_data
-def _force_positive(data: TensorMapping, names: List[str]) -> TensorDict:
+def force_positive(data: TensorMapping, names: List[str]) -> TensorDict:
"""Clamp all tensors defined by `names` to be greater than or equal to zero."""
out = {**data}
for name in names:
@@ -164,11 +193,16 @@ def _force_positive(data: TensorMapping, names: List[str]) -> TensorDict:
return out
+class AreaWeightedMean(Protocol):
+ def __call__(self, data: torch.Tensor, keepdim: bool) -> torch.Tensor: ...
+
+
def _force_conserve_dry_air(
input_data: TensorMapping,
gen_data: TensorMapping,
- area: torch.Tensor,
- sigma_coordinates: SigmaCoordinates,
+ area_weighted_mean: AreaWeightedMean,
+ vertical_coordinate: HybridSigmaPressureCoordinate,
+ precision: torch.dtype = torch.float64,
) -> TensorDict:
"""
Update the generated data to conserve dry air.
@@ -199,21 +233,20 @@ def _force_conserve_dry_air(
if input.surface_pressure is None:
raise ValueError("surface_pressure is required to force dry air conservation")
gen = ClimateData(gen_data)
- gen_dry_air = gen.surface_pressure_due_to_dry_air(sigma_coordinates)
- global_gen_dry_air = metrics.weighted_mean(gen_dry_air, weights=area, dim=(-2, -1))
- global_target_gen_dry_air = metrics.weighted_mean(
- input.surface_pressure_due_to_dry_air(sigma_coordinates),
- weights=area,
- dim=(-2, -1),
+ gen_dry_air = gen.surface_pressure_due_to_dry_air(vertical_coordinate)
+ global_gen_dry_air = area_weighted_mean(gen_dry_air.to(precision), keepdim=True)
+ global_target_gen_dry_air = area_weighted_mean(
+ input.surface_pressure_due_to_dry_air(vertical_coordinate).to(precision),
+ keepdim=True,
)
error = global_gen_dry_air - global_target_gen_dry_air
- new_gen_dry_air = gen_dry_air - error[..., None, None]
+ new_gen_dry_air = gen_dry_air.to(precision) - error
try:
- wat = gen.specific_total_water
+ wat = gen.specific_total_water.to(precision)
except KeyError:
raise ValueError("specific_total_water is required for conservation")
- ak_diff = sigma_coordinates.ak.diff()
- bk_diff = sigma_coordinates.bk.diff()
+ ak_diff = vertical_coordinate.ak.diff().to(precision)
+ bk_diff = vertical_coordinate.bk.diff().to(precision)
new_pressure = (new_gen_dry_air + (ak_diff * wat).sum(-1)) / (
1 - (bk_diff * wat).sum(-1)
)
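For readers following the solve above, the algebra is (a sketch, assuming the dry-air surface pressure is ps minus the water mass summed over layer pressure thicknesses):

    p_dry  = ps - sum_k q_k * dp_k,          dp_k = d(ak_k) + d(bk_k) * ps
           = ps * (1 - sum_k q_k * d(bk_k)) - sum_k q_k * d(ak_k)
    ps_new = (p_dry_target + sum_k q_k * d(ak_k)) / (1 - sum_k q_k * d(bk_k))

which is exactly the `new_pressure` expression above.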
@@ -223,7 +256,7 @@ def _force_conserve_dry_air(
def _force_zero_global_mean_moisture_advection(
gen_data: TensorMapping,
- area: torch.Tensor,
+ area_weighted_mean: Callable[[torch.Tensor], torch.Tensor],
) -> TensorDict:
"""
Update the generated data so advection conserves moisture.
@@ -232,15 +265,13 @@ def _force_zero_global_mean_moisture_advection(
Args:
gen_data: The generated data.
- area: (n_lat, n_lon) array containing relative gridcell area, in any
- units including unitless.
+ area_weighted_mean: Computes an area-weighted mean,
+ removing horizontal dimensions.
"""
gen = ClimateData(gen_data)
- mean_moisture_advection = metrics.weighted_mean(
+ mean_moisture_advection = area_weighted_mean(
gen.tendency_of_total_water_path_due_to_advection,
- weights=area,
- dim=(-2, -1),
)
gen.tendency_of_total_water_path_due_to_advection = (
gen.tendency_of_total_water_path_due_to_advection
@@ -252,8 +283,8 @@ def _force_zero_global_mean_moisture_advection(
def _force_conserve_moisture(
input_data: TensorMapping,
gen_data: TensorMapping,
- area: torch.Tensor,
- sigma_coordinates: SigmaCoordinates,
+ area_weighted_mean: AreaWeightedMean,
+ vertical_coordinate: HybridSigmaPressureCoordinate,
timestep: datetime.timedelta,
terms_to_modify: Literal[
"precipitation",
@@ -274,9 +305,9 @@ def _force_conserve_moisture(
Args:
input_data: The input data.
gen_data: The generated data one timestep after the input data.
- area: (n_lat, n_lon) array containing relative gridcell area, in any
- units including unitless.
- sigma_coordinates: The sigma coordinates.
+ area_weighted_mean: Computes an area-weighted mean,
+ removing horizontal dimensions.
+        vertical_coordinate: The hybrid sigma-pressure vertical coordinate.
timestep: Timestep of the model.
terms_to_modify: Which terms to modify, in addition to modifying surface
pressure to conserve dry air mass. One of:
@@ -288,20 +319,14 @@ def _force_conserve_moisture(
input = ClimateData(input_data)
gen = ClimateData(gen_data)
- gen_total_water_path = gen.total_water_path(sigma_coordinates)
+ gen_total_water_path = gen.total_water_path(vertical_coordinate)
timestep_seconds = timestep / datetime.timedelta(seconds=1)
twp_total_tendency = (
- gen_total_water_path - input.total_water_path(sigma_coordinates)
+ gen_total_water_path - input.total_water_path(vertical_coordinate)
) / timestep_seconds
- twp_tendency_global_mean = metrics.weighted_mean(
- twp_total_tendency, weights=area, dim=(-2, -1)
- )
- evaporation_global_mean = metrics.weighted_mean(
- gen.evaporation_rate, weights=area, dim=(-2, -1)
- )
- precipitation_global_mean = metrics.weighted_mean(
- gen.precipitation_rate, weights=area, dim=(-2, -1)
- )
+ twp_tendency_global_mean = area_weighted_mean(twp_total_tendency, keepdim=True)
+ evaporation_global_mean = area_weighted_mean(gen.evaporation_rate, keepdim=True)
+ precipitation_global_mean = area_weighted_mean(gen.precipitation_rate, keepdim=True)
if terms_to_modify.endswith("precipitation"):
# We want to achieve
# global_mean(twp_total_tendency) = (
@@ -324,20 +349,16 @@ def _force_conserve_moisture(
# new_precip_rate = (
# new_global_precip_rate / current_global_precip_rate
# ) * current_precip_rate
- gen.precipitation_rate = (
- gen.precipitation_rate
- * (new_precipitation_global_mean / precipitation_global_mean)[
- ..., None, None
- ]
+ gen.precipitation_rate = gen.precipitation_rate * (
+ new_precipitation_global_mean / precipitation_global_mean
)
elif terms_to_modify.endswith("evaporation"):
# Derived similarly as for "precipitation" case.
new_evaporation_global_mean = (
twp_tendency_global_mean + precipitation_global_mean
)
- gen.evaporation_rate = (
- gen.evaporation_rate
- * (new_evaporation_global_mean / evaporation_global_mean)[..., None, None]
+ gen.evaporation_rate = gen.evaporation_rate * (
+ new_evaporation_global_mean / evaporation_global_mean
)
if terms_to_modify.startswith("advection"):
# Having already corrected the global-mean budget, we recompute
diff --git a/fme/fme/core/corrector/ocean.py b/fme/fme/core/corrector/ocean.py
new file mode 100644
index 0000000..b890090
--- /dev/null
+++ b/fme/fme/core/corrector/ocean.py
@@ -0,0 +1,83 @@
+import dataclasses
+import datetime
+from types import MappingProxyType
+from typing import Any, List, Mapping, Optional
+
+import dacite
+
+from fme.core.coordinates import HybridSigmaPressureCoordinate
+from fme.core.corrector.corrector import force_positive
+from fme.core.corrector.registry import CorrectorABC, CorrectorConfigProtocol
+from fme.core.gridded_ops import GriddedOperations
+from fme.core.masking import MaskingConfig
+from fme.core.registry.corrector import CorrectorSelector
+from fme.core.stacker import Stacker
+from fme.core.typing_ import TensorMapping
+
+OCEAN_FIELD_NAME_PREFIXES = MappingProxyType(
+ {
+ "surface_height": ["zos"],
+ "salinity": ["so_"],
+ "potential_temperature": ["thetao_"],
+ "zonal_velocity": ["uo_"],
+ "meridional_velocity": ["vo_"],
+ }
+)
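For illustration, a sketch of how these prefixes are resolved (assuming, as ClimateData's properties above rely on, that Stacker stacks the level-suffixed fields along a trailing dimension):

    stacker = Stacker(OCEAN_FIELD_NAME_PREFIXES)
    data = {"thetao_0": torch.rand(4, 8, 16), "thetao_1": torch.rand(4, 8, 16)}
    theta = stacker("potential_temperature", data)  # shape (4, 8, 16, 2)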
+
+
+@CorrectorSelector.register("ocean_corrector")
+@dataclasses.dataclass
+class OceanCorrectorConfig(CorrectorConfigProtocol):
+ masking: Optional[MaskingConfig] = None
+ force_positive_names: List[str] = dataclasses.field(default_factory=list)
+
+ def build(
+ self,
+ gridded_operations: GriddedOperations,
+ vertical_coordinate: HybridSigmaPressureCoordinate,
+ timestep: datetime.timedelta,
+ ):
+ return OceanCorrector(
+ config=self,
+ gridded_operations=gridded_operations,
+ vertical_coordinate=vertical_coordinate,
+ timestep=timestep,
+ )
+
+ @classmethod
+ def from_state(cls, state: Mapping[str, Any]) -> "OceanCorrectorConfig":
+ return dacite.from_dict(
+ data_class=cls, data=state, config=dacite.Config(strict=True)
+ )
+
+
+class OceanCorrector(CorrectorABC):
+ def __init__(
+ self,
+ config: OceanCorrectorConfig,
+ gridded_operations: GriddedOperations,
+ vertical_coordinate: HybridSigmaPressureCoordinate,
+ timestep: datetime.timedelta,
+ ):
+ self._config = config
+ self._gridded_operations = gridded_operations
+ self._vertical_coordinates = vertical_coordinate
+ self._timestep = timestep
+
+ if config.masking is not None:
+ self._masking = config.masking.build()
+ else:
+ self._masking = None
+ self._stacker = Stacker(OCEAN_FIELD_NAME_PREFIXES)
+
+ def __call__(
+ self,
+ input_data: TensorMapping,
+ gen_data: TensorMapping,
+ forcing_data: TensorMapping,
+ ) -> TensorMapping:
+ if self._masking is not None:
+ gen_data = self._masking(self._stacker, gen_data, input_data)
+ if len(self._config.force_positive_names) > 0:
+ gen_data = force_positive(gen_data, self._config.force_positive_names)
+ return gen_data
diff --git a/fme/fme/core/corrector/registry.py b/fme/fme/core/corrector/registry.py
new file mode 100644
index 0000000..e0efef1
--- /dev/null
+++ b/fme/fme/core/corrector/registry.py
@@ -0,0 +1,34 @@
+import abc
+import datetime
+from typing import Any, Mapping, Protocol
+
+from fme.core.coordinates import HybridSigmaPressureCoordinate
+from fme.core.gridded_ops import GriddedOperations
+from fme.core.typing_ import TensorMapping
+
+
+class CorrectorConfigProtocol(Protocol):
+ def build(
+ self,
+ gridded_operations: GriddedOperations,
+ vertical_coordinate: HybridSigmaPressureCoordinate,
+ timestep: datetime.timedelta,
+ ) -> "CorrectorABC": ...
+
+ @classmethod
+ def from_state(cls, state: Mapping[str, Any]) -> "CorrectorConfigProtocol":
+ """
+        Create a corrector config from a dictionary containing all the
+        information needed to build the corrector.
+ """
+ ...
+
+
+class CorrectorABC(abc.ABC):
+ @abc.abstractmethod
+ def __call__(
+ self,
+ input_data: TensorMapping,
+ gen_data: TensorMapping,
+ forcing_data: TensorMapping,
+ ) -> TensorMapping: ...
diff --git a/fme/fme/core/corrector/test_corrector.py b/fme/fme/core/corrector/test_corrector.py
new file mode 100644
index 0000000..25b67f0
--- /dev/null
+++ b/fme/fme/core/corrector/test_corrector.py
@@ -0,0 +1,260 @@
+import datetime
+from typing import Callable, Optional, Tuple
+
+import numpy as np
+import pytest
+import torch
+
+from fme.ace.inference.derived_variables import total_water_path_budget_residual
+from fme.core import ClimateData, metrics
+from fme.core.climate_data import compute_dry_air_absolute_differences
+from fme.core.coordinates import HybridSigmaPressureCoordinate
+from fme.core.corrector.ocean import OceanCorrector
+from fme.core.gridded_ops import GriddedOperations, HEALPixOperations, LatLonOperations
+from fme.core.registry.corrector import CorrectorSelector
+from fme.core.typing_ import TensorMapping
+
+from .corrector import (
+ _force_conserve_dry_air,
+ _force_conserve_moisture,
+ _force_zero_global_mean_moisture_advection,
+ force_positive,
+)
+
+TIMESTEP = datetime.timedelta(hours=6)
+
+
+def get_dry_air_nonconservation(
+ data: TensorMapping,
+ area_weighted_mean: Callable[[torch.Tensor], torch.Tensor],
+ vertical_coordinate: HybridSigmaPressureCoordinate,
+):
+ """
+ Computes the time-average one-step absolute difference in surface pressure due to
+ changes in globally integrated dry air.
+
+ Args:
+ data: A mapping from variable name to tensor of shape
+ [sample, time, lat, lon], in physical units. specific_total_water in kg/kg
+ and surface_pressure in Pa must be present.
+ area_weighted_mean: Computes the area-weighted mean of a tensor, removing the
+ horizontal dimensions.
+ vertical_coordinate: The vertical coordinates of the model.
+ """
+ return compute_dry_air_absolute_differences(
+ ClimateData(data),
+ area_weighted_mean=area_weighted_mean,
+ vertical_coordinate=vertical_coordinate,
+ ).mean()
+
+
+def test_force_no_global_mean_moisture_advection():
+ torch.random.manual_seed(0)
+ data = {
+ "tendency_of_total_water_path_due_to_advection": torch.rand(size=(3, 2, 5, 5)),
+ }
+ area_weights = 1.0 + torch.rand(size=(5, 5))
+ original_mean = metrics.weighted_mean(
+ data["tendency_of_total_water_path_due_to_advection"],
+ weights=area_weights,
+ dim=[-2, -1],
+ )
+ assert (original_mean.abs() > 0.0).all()
+ fixed_data = _force_zero_global_mean_moisture_advection(
+ data,
+ area_weighted_mean=LatLonOperations(area_weights).area_weighted_mean,
+ )
+ new_mean = metrics.weighted_mean(
+ fixed_data["tendency_of_total_water_path_due_to_advection"],
+ weights=area_weights,
+ dim=[-2, -1],
+ )
+ assert (new_mean.abs() < original_mean.abs()).all()
+ np.testing.assert_almost_equal(new_mean.cpu().numpy(), 0.0, decimal=6)
+
+
+@pytest.mark.parametrize(
+ "size, use_area",
+ [
+ pytest.param((3, 2, 5, 5), True, id="latlon"),
+ pytest.param((3, 12, 2, 3, 3), False, id="healpix"),
+ ],
+)
+def test_force_conserve_dry_air(size: Tuple[int, ...], use_area: bool):
+ torch.random.manual_seed(0)
+ data = {
+ "PRESsfc": 10.0 + torch.rand(size=size),
+ "specific_total_water_0": torch.rand(size=size),
+ "specific_total_water_1": torch.rand(size=size),
+ }
+ vertical_coordinate = HybridSigmaPressureCoordinate(
+ ak=torch.asarray([3.0, 1.0, 0.0]), bk=torch.asarray([0.0, 0.6, 1.0])
+ )
+ if use_area:
+ area_weights: Optional[torch.Tensor] = 1.0 + torch.rand(
+ size=(size[-2], size[-1])
+ )
+ else:
+ area_weights = None
+ if area_weights is not None:
+ gridded_operations: GriddedOperations = LatLonOperations(area_weights)
+ else:
+ gridded_operations = HEALPixOperations()
+ original_nonconservation = get_dry_air_nonconservation(
+ data,
+ vertical_coordinate=vertical_coordinate,
+ area_weighted_mean=gridded_operations.area_weighted_mean,
+ )
+ assert original_nonconservation > 0.0
+ in_data = {k: v.select(dim=1, index=0) for k, v in data.items()}
+ out_data = {k: v.select(dim=1, index=1) for k, v in data.items()}
+ fixed_out_data = _force_conserve_dry_air(
+ in_data,
+ out_data,
+ vertical_coordinate=vertical_coordinate,
+ area_weighted_mean=gridded_operations.area_weighted_mean,
+ )
+ new_data = {
+ k: torch.stack([v, fixed_out_data[k]], dim=1) for k, v in in_data.items()
+ }
+ new_nonconservation = get_dry_air_nonconservation(
+ new_data,
+ vertical_coordinate=vertical_coordinate,
+ area_weighted_mean=gridded_operations.area_weighted_mean,
+ )
+ assert new_nonconservation < original_nonconservation
+ np.testing.assert_almost_equal(new_nonconservation.cpu().numpy(), 0.0, decimal=6)
+
+
+@pytest.mark.parametrize("dataset", ["fv3", "e3sm"])
+@pytest.mark.parametrize(
+ "global_only, terms_to_modify",
+ [
+ (True, "precipitation"),
+ (True, "evaporation"),
+ (False, "advection_and_precipitation"),
+ (False, "advection_and_evaporation"),
+ ],
+)
+@pytest.mark.parametrize(
+ "size, use_area",
+ [
+ pytest.param((3, 2, 5, 5), True, id="latlon"),
+ pytest.param((3, 12, 2, 3, 3), False, id="healpix"),
+ ],
+)
+def test_force_conserve_moisture(
+ dataset: str,
+ global_only: bool,
+ terms_to_modify,
+ size: Tuple[int, ...],
+ use_area: bool,
+):
+ torch.random.manual_seed(0)
+ if dataset == "fv3":
+ data = {
+ "PRESsfc": 10.0 + torch.rand(size=size),
+ "specific_total_water_0": torch.rand(size=size),
+ "specific_total_water_1": torch.rand(size=size),
+ "PRATEsfc": torch.rand(size=size),
+ "LHTFLsfc": torch.rand(size=size),
+ "tendency_of_total_water_path_due_to_advection": torch.rand(size=size),
+ }
+ if dataset == "e3sm":
+ data = {
+ "PS": 10.0 + torch.rand(size=size),
+ "specific_total_water_0": torch.rand(size=size),
+ "specific_total_water_1": torch.rand(size=size),
+ "surface_precipitation_rate": torch.rand(size=size),
+ "LHFLX": torch.rand(size=size),
+ "tendency_of_total_water_path_due_to_advection": torch.rand(size=size),
+ }
+ vertical_coordinate = HybridSigmaPressureCoordinate(
+ ak=torch.asarray([3.0, 1.0, 0.0]), bk=torch.asarray([0.0, 0.6, 1.0])
+ )
+ if use_area:
+ ops: GriddedOperations = LatLonOperations(1.0 + torch.rand(size=(5, 5)))
+ else:
+ ops = HEALPixOperations()
+ data["tendency_of_total_water_path_due_to_advection"] -= ops.area_weighted_mean(
+ data["tendency_of_total_water_path_due_to_advection"], keepdim=True
+ )
+ original_budget_residual = total_water_path_budget_residual(
+ ClimateData(data),
+ vertical_coordinate=vertical_coordinate,
+ timestep=TIMESTEP,
+ )[:, 1] # no meaning for initial value data, want first timestep
+ if global_only:
+ original_budget_residual = ops.area_weighted_mean(
+ original_budget_residual, keepdim=True
+ )
+ original_budget_residual = original_budget_residual.cpu().numpy()
+ original_dry_air = (
+ ClimateData(data)
+ .surface_pressure_due_to_dry_air(vertical_coordinate)
+ .cpu()
+ .numpy()
+ )
+ assert np.any(np.abs(original_budget_residual) > 0.0)
+ in_data = {k: v.select(dim=1, index=0) for k, v in data.items()}
+ out_data = {k: v.select(dim=1, index=1) for k, v in data.items()}
+ fixed_out_data = _force_conserve_moisture(
+ in_data,
+ out_data,
+ vertical_coordinate=vertical_coordinate,
+ area_weighted_mean=ops.area_weighted_mean,
+ timestep=TIMESTEP,
+ terms_to_modify=terms_to_modify,
+ )
+ new_data = {
+ k: torch.stack([v, fixed_out_data[k]], dim=1) for k, v in in_data.items()
+ }
+ new_budget_residual = total_water_path_budget_residual(
+ ClimateData(new_data),
+ vertical_coordinate=vertical_coordinate,
+ timestep=TIMESTEP,
+ )[:, 1] # no meaning for initial value data, want first timestep
+ new_dry_air = (
+ ClimateData(data)
+ .surface_pressure_due_to_dry_air(vertical_coordinate)
+ .cpu()
+ .numpy()
+ )
+
+ global_budget_residual = ops.area_weighted_mean(new_budget_residual).cpu().numpy()
+ np.testing.assert_almost_equal(global_budget_residual, 0.0, decimal=6)
+
+ if not global_only:
+ new_budget_residual = new_budget_residual.cpu().numpy()
+ assert np.all(np.abs(new_budget_residual) < np.abs(original_budget_residual))
+ np.testing.assert_almost_equal(new_budget_residual, 0.0, decimal=6)
+
+ np.testing.assert_almost_equal(new_dry_air, original_dry_air, decimal=6)
+
+
+def test_force_positive():
+ data = {
+ "foo": torch.tensor([[-1.0, 0.0], [0.0, 1.0], [1.0, 2.0]]),
+ "bar": torch.tensor([[-1.0, 0.0], [0.0, -3.0], [1.0, 2.0]]),
+ }
+ original_min = torch.min(data["foo"])
+ assert original_min < 0.0
+ fixed_data = force_positive(data, ["foo"])
+ new_min = torch.min(fixed_data["foo"])
+ # Ensure the minimum value of 'foo' is now 0
+ torch.testing.assert_close(new_min, torch.tensor(0.0))
+ # Ensure other variables are not modified
+ torch.testing.assert_close(fixed_data["bar"], data["bar"])
+
+
+def test_corrector_selector():
+ selector = CorrectorSelector(
+ type="ocean_corrector",
+ config={"masking": {"mask_name": "mask", "mask_value": 1}},
+ )
+ ops: GriddedOperations = LatLonOperations(1.0 + torch.rand(size=(5, 5)))
+ vertical: HybridSigmaPressureCoordinate = HybridSigmaPressureCoordinate(
+ ak=torch.tensor([1.0, 0.5, 0.0]), bk=torch.tensor([0.0, 0.5, 1.0])
+ )
+ corrector = selector.build(ops, vertical, TIMESTEP)
+ assert isinstance(corrector, OceanCorrector)
diff --git a/fme/fme/core/data_loading/config.py b/fme/fme/core/data_loading/config.py
deleted file mode 100644
index 735cb40..0000000
--- a/fme/fme/core/data_loading/config.py
+++ /dev/null
@@ -1,122 +0,0 @@
-import dataclasses
-from typing import Literal, Optional, Sequence, Union
-
-import xarray as xr
-
-from fme.core.distributed import Distributed
-
-
-@dataclasses.dataclass
-class Slice:
- """
- Configuration of a python `slice` built-in.
-
- Required because `slice` cannot be initialized directly by dacite.
-
- Attributes:
- start: Start index of the slice.
- stop: Stop index of the slice.
- step: Step of the slice.
- """
-
- start: Optional[int] = None
- stop: Optional[int] = None
- step: Optional[int] = None
-
- @property
- def slice(self) -> slice:
- return slice(self.start, self.stop, self.step)
-
-
-@dataclasses.dataclass
-class TimeSlice:
- """
- Configuration of a slice of times. Step is an integer-valued index step.
-
- Note: start_time and stop_time may be provided as partial time strings and the
- stop_time will be included in the slice. See more details in `Xarray docs`_.
-
- Attributes:
- start_time: Start time of the slice.
- stop_time: Stop time of the slice.
- step: Step of the slice.
-
- .. _Xarray docs:
- https://docs.xarray.dev/en/latest/user-guide/weather-climate.html#non-standard-calendars-and-dates-outside-the-nanosecond-precision-range # noqa
- """
-
- start_time: Optional[str] = None
- stop_time: Optional[str] = None
- step: Optional[int] = None
-
- def slice(self, times: xr.CFTimeIndex) -> slice:
- return times.slice_indexer(self.start_time, self.stop_time, self.step)
-
-
-@dataclasses.dataclass
-class XarrayDataConfig:
- """
- Attributes:
- data_path: Path to the data.
- file_pattern: Glob pattern to match files in the data_path.
- n_repeats: Number of times to repeat the dataset (in time). It is up
- to the user to ensure that the input dataset to repeat results in
- data that is reasonably continuous across repetitions.
- engine: Backend for xarray.open_dataset. Currently supported options
- are "netcdf4" (the default) and "h5netcdf". Only valid when using
- XarrayDataset.
- spatial_dimensions: Specifies the spatial dimensions for the grid, default
- is lat/lon.
- subset: Slice defining a subset of the XarrayDataset to load. This can
- either be a `Slice` of integer indices or a `TimeSlice` of timestamps.
- infer_timestep: Whether to infer the timestep from the provided data.
- This should be set to True (the default) for ACE training. It may
- be useful to toggle this to False for applications like downscaling,
- which do not depend on the timestep of the data and therefore lack
- the additional requirement that the data be ordered and evenly
- spaced in time. It must be set to True if n_repeats > 1 in order
- to be able to infer the full time coordinate.
- """
-
- data_path: str
- file_pattern: str = "*.nc"
- n_repeats: int = 1
- engine: Optional[Literal["netcdf4", "h5netcdf", "zarr"]] = None
- spatial_dimensions: Literal["healpix", "latlon"] = "latlon"
- subset: Union[Slice, TimeSlice] = dataclasses.field(default_factory=Slice)
- infer_timestep: bool = True
-
- def __post_init__(self):
- if self.n_repeats > 1 and not self.infer_timestep:
- raise ValueError(
- "infer_timestep must be True if n_repeats is greater than 1"
- )
-
-
-@dataclasses.dataclass
-class DataLoaderConfig:
- """
- Attributes:
- dataset: A sequence of configurations each defining a dataset
- to be loaded. This sequence of datasets will be concatenated.
- batch_size: Number of samples per batch.
- num_data_workers: Number of parallel workers to use for data loading.
- prefetch_factor: how many batches a single data worker will attempt to
- hold in host memory at a given time.
- strict_ensemble: Whether to enforce that the ensemble members have the same
- dimensions and coordinates.
- """
-
- dataset: Sequence[XarrayDataConfig]
- batch_size: int
- num_data_workers: int
- prefetch_factor: Optional[int] = None
- strict_ensemble: bool = True
-
- def __post_init__(self):
- dist = Distributed.get_instance()
- if self.batch_size % dist.world_size != 0:
- raise ValueError(
- "batch_size must be divisible by the number of parallel "
- f"workers, got {self.batch_size} and {dist.world_size}"
- )
diff --git a/fme/fme/core/data_loading/data_typing.py b/fme/fme/core/data_loading/data_typing.py
deleted file mode 100644
index a039db1..0000000
--- a/fme/fme/core/data_loading/data_typing.py
+++ /dev/null
@@ -1,276 +0,0 @@
-import abc
-import dataclasses
-import datetime
-from collections import namedtuple
-from typing import Dict, List, Literal, Mapping, Optional, Tuple
-
-import numpy as np
-import torch
-import xarray as xr
-from astropy_healpix import HEALPix
-
-from fme.core.typing_ import TensorDict, TensorMapping
-from fme.core.winds import lon_lat_to_xyz
-
-VariableMetadata = namedtuple("VariableMetadata", ["units", "long_name"])
-
-
-@dataclasses.dataclass
-class SigmaCoordinates:
- """
- Defines pressure at interface levels according to the following formula:
- p(k) = a(k) + b(k)*ps
-
- where ps is the surface pressure, a and b are the sigma coordinates.
-
- Attributes:
- ak: a(k) coefficients as a 1-dimensional tensor
- bk: b(k) coefficients as a 1-dimensional tensor
- """
-
- ak: torch.Tensor
- bk: torch.Tensor
-
- @property
- def coords(self) -> Mapping[str, np.ndarray]:
- return {"ak": self.ak.cpu().numpy(), "bk": self.bk.cpu().numpy()}
-
- def to(self, device: str) -> "SigmaCoordinates":
- return SigmaCoordinates(
- ak=self.ak.to(device),
- bk=self.bk.to(device),
- )
-
- def as_dict(self) -> TensorMapping:
- return {"ak": self.ak, "bk": self.bk}
-
-
-@dataclasses.dataclass
-class HorizontalCoordinates(abc.ABC):
- """
- Parent class for horizontal coordinate system grids.
- Contains coords which must be subclassed to provide the coordinates.
- """
-
- @property
- @abc.abstractmethod
- def coords(self) -> Mapping[str, np.ndarray]:
- pass
-
- @property
- @abc.abstractmethod
- def xyz(self) -> Tuple[float, float, float]:
- pass
-
- @property
- @abc.abstractmethod
- def dims(self) -> List[str]:
- pass
-
- @property
- @abc.abstractmethod
- def default_sizes(self) -> Dict[str, int]:
- pass
-
- @property
- @abc.abstractmethod
- def grid(self) -> Literal["equiangular", "legendre-gauss", "healpix"]:
- pass
-
- # A temporary solution for training which allows us to aggregate along the
- # latitude dimension.
- # TODO: https://github.com/ai2cm/full-model/issues/1003
- @abc.abstractmethod
- def get_lat(self) -> torch.Tensor:
- pass
-
-
-@dataclasses.dataclass
-class LatLonCoordinates(HorizontalCoordinates):
- """
- Defines a (latitude, longitude) grid.
-
- Attributes:
- lat: 1-dimensional tensor of latitudes
- lon: 1-dimensional tensor of longitudes
- """
-
- lon: torch.Tensor
- lat: torch.Tensor
-
- lat_name: str = "lat"
- lon_name: str = "lon"
-
- @property
- def coords(self) -> Mapping[str, np.ndarray]:
- # TODO: Replace with lat/lon name?
- return {
- "lat": self.lat.cpu().numpy(),
- "lon": self.lon.cpu().numpy(),
- }
-
- @property
- def xyz(self) -> Tuple[float, float, float]:
- lats, lons = np.broadcast_arrays(self.lat[:, None], self.lon[None, :])
- return lon_lat_to_xyz(lons, lats)
-
- # TODO: https://github.com/ai2cm/full-model/issues/1003
- def get_lat(self) -> torch.Tensor:
- return self.lat
-
- @property
- def dims(self) -> List[str]:
- return [self.lat_name, self.lon_name]
-
- @property
- def default_sizes(self) -> Dict[str, int]:
- return {self.lat_name: 12, self.lon_name: 6}
-
- @property
- def grid(self) -> Literal["equiangular", "legendre-gauss"]:
- if torch.allclose(
- self.lat[1:] - self.lat[:-1],
- self.lat[1] - self.lat[0],
- ):
- return "equiangular"
- else:
- return "legendre-gauss"
-
-
-@dataclasses.dataclass
-class HEALPixCoordinates(HorizontalCoordinates):
- """
- Defines a HEALPix (face, height, width) grid. See https://healpix.jpl.nasa.gov/ for
- more information.
-
- Attributes:
- face: 1-dimensional tensor of faces
- height: 1-dimensional tensor of heights
- width: 1-dimensional tensor of widths
- """
-
- face: torch.Tensor
- height: torch.Tensor
- width: torch.Tensor
-
- @property
- def coords(self) -> Mapping[str, np.ndarray]:
- return {
- "face": self.face.cpu().numpy(),
- "height": self.height.cpu().numpy(),
- "width": self.width.cpu().numpy(),
- }
-
- @property
- def xyz(self) -> Tuple[float, float, float]:
- hp = HEALPix(nside=len(self.height), order="ring")
- return hp.healpix_to_xyz([self.face, self.height, self.width])
-
- @property
- def dims(self) -> List[str]:
- return ["face", "height", "width"]
-
- @property
- def default_sizes(cls) -> Dict[str, int]:
- return {"face": 12, "width": 64, "height": 64}
-
- # TODO: https://github.com/ai2cm/full-model/issues/1003
- # This is currently the dummy solution.
- def get_lat(self) -> torch.Tensor:
- raise NotImplementedError(
- "healpix does not support get_lat. If latitude is needed \
- for some reason, you may use this class's self.xyz property to derive it."
- )
-
- @property
- def grid(self) -> Literal["healpix"]:
- return "healpix"
-
-
-class Dataset(torch.utils.data.Dataset, abc.ABC):
- @abc.abstractproperty
- def metadata(self) -> Mapping[str, VariableMetadata]:
- ...
-
- @abc.abstractproperty
- def area_weights(self) -> torch.Tensor:
- ...
-
- @abc.abstractproperty
- def horizontal_coordinates(self) -> HorizontalCoordinates:
- ...
-
- @abc.abstractproperty
- def sigma_coordinates(self) -> SigmaCoordinates:
- ...
-
- @abc.abstractproperty
- def is_remote(self) -> bool:
- ...
-
- @abc.abstractmethod
- def get_sample_by_time_slice(
- self, time_slice: slice
- ) -> Tuple[TensorDict, xr.DataArray]:
- """
- Returns a sample of data for the given time slice.
-
- Args:
- time_slice: The time slice to return data for.
-
- Returns:
- A tuple whose first item is a mapping from variable
- name to tensor of shape [n_time, n_lat, n_lon] and
- whose second item is a time coordinate array.
- """
- ...
-
-
-@dataclasses.dataclass
-class GriddedData:
- """
- Data as required for pytorch training.
-
- The data is assumed to be gridded, and attributes are included for
- performing operations on gridded data.
-
- Attributes:
- loader: torch DataLoader, which returns batches of type
- TensorMapping where keys indicate variable name.
- Each tensor has shape
- [batch_size, time_window_size, n_channels, n_lat, n_lon].
- metadata: Metadata for each variable.
- area_weights: Weights for each grid cell, used for computing area-weighted
- averages. Has shape [n_lat, n_lon].
- sigma_coordinates: Sigma coordinates for each grid cell, used for computing
- pressure levels.
- horizontal_coordinates: Lat/lon coordinates for the data.
- timestep: Timestep of the model.
- sampler: Optional sampler for the data loader. Provided to allow support for
- distributed training.
- """
-
- loader: torch.utils.data.DataLoader
- metadata: Mapping[str, VariableMetadata]
- area_weights: torch.Tensor
- sigma_coordinates: SigmaCoordinates
- horizontal_coordinates: HorizontalCoordinates
- timestep: datetime.timedelta
- sampler: Optional[torch.utils.data.Sampler] = None
-
- @property
- def dataset(self) -> Dataset:
- return self.loader.dataset
-
- @property
- def coords(self) -> Mapping[str, np.ndarray]:
- return {
- **self.horizontal_coordinates.coords,
- **self.sigma_coordinates.coords,
- }
-
- @property
- def grid(self) -> Literal["equiangular", "legendre-gauss", "healpix"]:
- """If the latitudes are equiangular, assume a regular grid and otherwise
- assume a gaussian, or 'legendre-gauss' grid."""
- return self.horizontal_coordinates.grid
diff --git a/fme/fme/core/data_loading/getters.py b/fme/fme/core/data_loading/getters.py
deleted file mode 100644
index 6a933c5..0000000
--- a/fme/fme/core/data_loading/getters.py
+++ /dev/null
@@ -1,221 +0,0 @@
-import logging
-import warnings
-from typing import List, Sequence
-
-import numpy as np
-import torch.utils.data
-import xarray as xr
-from torch.utils.data.distributed import DistributedSampler
-from torch.utils.data.sampler import RandomSampler
-
-from fme.core.data_loading.config import DataLoaderConfig, XarrayDataConfig
-from fme.core.device import using_gpu
-from fme.core.distributed import Distributed
-
-from ._xarray import XarrayDataset, as_index_slice, subset_dataset
-from .data_typing import GriddedData
-from .inference import (
- ExplicitIndices,
- ForcingDataLoaderConfig,
- InferenceDataLoaderConfig,
- InferenceDataset,
-)
-from .requirements import DataRequirements
-from .utils import BatchData
-
-
-def _all_same(iterable, cmp=lambda x, y: x == y):
- it = iter(iterable)
- try:
- first = next(it)
- except StopIteration:
- return True
- return all(cmp(first, rest) for rest in it)
-
-
-def get_datasets(
- dataset_configs: Sequence[XarrayDataConfig], requirements: DataRequirements
-) -> List[XarrayDataset]:
- datasets = []
- for config in dataset_configs:
- dataset = XarrayDataset(config, requirements)
- index_slice = as_index_slice(config.subset, dataset)
- dataset = subset_dataset(dataset, index_slice)
- datasets.append(dataset)
- return datasets
-
-
-def get_dataset(
- dataset_configs: Sequence[XarrayDataConfig],
- requirements: DataRequirements,
- strict: bool = True,
-) -> torch.utils.data.ConcatDataset[XarrayDataset]:
- datasets = get_datasets(dataset_configs, requirements)
-
- if not _all_same([d.metadata for d in datasets]):
- if strict:
- raise ValueError("Metadata for each ensemble member should be the same.")
- else:
- warnings.warn(
- "Metadata for each ensemble member are not the same. You may be "
- "concatenating incompatible datasets."
- )
- sigma_coords = [d.sigma_coordinates for d in datasets]
- ak, bk = list(
- zip(*[(s.ak.cpu().numpy(), s.bk.cpu().numpy()) for s in sigma_coords])
- )
- if not (_all_same(ak, cmp=np.allclose) and _all_same(bk, cmp=np.allclose)):
- if strict:
- raise ValueError(
- "Sigma coordinates for each ensemble member should be the same."
- )
- else:
- warnings.warn(
- "Vertical coordinates for each ensemble member are not the same. You "
- "may be concatenating incompatible datasets."
- )
-
- ensemble = torch.utils.data.ConcatDataset(datasets)
- ensemble.metadata = datasets[0].metadata # type: ignore
- ensemble.area_weights = datasets[0].area_weights # type: ignore
- ensemble.sigma_coordinates = datasets[0].sigma_coordinates # type: ignore
- ensemble.timestep = datasets[0].timestep # type: ignore
- ensemble.horizontal_coordinates = datasets[0].horizontal_coordinates # type: ignore
- ensemble.is_remote = any(d.is_remote for d in datasets) # type: ignore
- return ensemble
-
-
-def get_data_loader(
- config: DataLoaderConfig,
- train: bool,
- requirements: DataRequirements,
-) -> GriddedData:
- """
- Args:
- config: Parameters for the data loader.
- train: Whether loader is intended for training or validation data; if True,
- then data will be shuffled.
- requirements: Data requirements for the model.
- """
- dataset = get_dataset(config.dataset, requirements, strict=config.strict_ensemble)
- dist = Distributed.get_instance()
-
- if dist.is_distributed():
- sampler = DistributedSampler(dataset, shuffle=train)
- else:
- sampler = RandomSampler(dataset) if train else None
-
- if dataset.is_remote:
- # GCSFS and S3FS are not fork-safe, so we need to use forkserver
- mp_context = "forkserver"
- persistent_workers = True
- else:
- mp_context = None
- persistent_workers = False
-
- dataloader = torch.utils.data.DataLoader(
- dataset,
- batch_size=dist.local_batch_size(int(config.batch_size)),
- num_workers=config.num_data_workers,
- sampler=sampler,
- drop_last=True,
- pin_memory=using_gpu(),
- collate_fn=BatchData.from_sample_tuples,
- prefetch_factor=config.prefetch_factor,
- multiprocessing_context=mp_context,
- persistent_workers=persistent_workers,
- )
-
- if len(dataloader) == 0:
- raise ValueError(
- "No batches in dataloader: "
- f"{len(dataloader.dataset)} samples, {len(dataloader)} batches. "
- f"Batch size is {dataloader.batch_size}"
- )
-
- return GriddedData(
- loader=dataloader,
- metadata=dataset.metadata,
- area_weights=dataset.area_weights,
- sampler=sampler,
- sigma_coordinates=dataset.sigma_coordinates,
- timestep=dataset.timestep,
- horizontal_coordinates=dataset.horizontal_coordinates,
- )
-
-
-def get_inference_data(
- config: InferenceDataLoaderConfig,
- forward_steps_in_memory: int,
- requirements: DataRequirements,
-) -> GriddedData:
- """
- Args:
- config: Parameters for the data loader.
- forward_steps_in_memory: Number of forward steps to keep in memory at once.
- requirements: Data requirements for the model.
-
- Returns:
- A data loader for inference with coordinates and metadata.
- """
- dataset = InferenceDataset(config, forward_steps_in_memory, requirements)
-
- if dataset.is_remote:
- # GCSFS and S3FS are not fork-safe, so we need to use forkserver
- # persist workers since startup is slow
- mp_context = "forkserver"
- persistent_workers = True
- else:
- mp_context = None
- persistent_workers = False
-
- logging.info(f"Multiprocessing inference context: {mp_context or 'fork'}")
-
- # we roll our own batching in InferenceDataset, which is why batch_size=None below
- loader = torch.utils.data.DataLoader(
- dataset,
- batch_size=None,
- num_workers=config.num_data_workers,
- shuffle=False,
- pin_memory=using_gpu(),
- multiprocessing_context=mp_context,
- persistent_workers=persistent_workers,
- )
- return GriddedData(
- loader=loader,
- metadata=dataset.metadata,
- area_weights=dataset.area_weights,
- sigma_coordinates=dataset.sigma_coordinates,
- timestep=dataset.timestep,
- horizontal_coordinates=dataset.horizontal_coordinates,
- )
-
-
-def get_forcing_data(
- config: ForcingDataLoaderConfig,
- forward_steps_in_memory: int,
- requirements: DataRequirements,
- initial_times: xr.DataArray,
-) -> GriddedData:
- """Return a GriddedData loader for forcing data only. This function determines the
- start indices for the forcing data based on the initial times provided.
-
- Args:
- config: Parameters for the forcing data loader.
- forward_steps_in_memory: Number of forward steps to provide per window of
- forcing data that will be returned by loader.
- requirements: Data requirements for the forcing data.
- initial_times: Desired initial times for the forcing data. This must be a 1D
- data array, whose length determines the ensemble size.
-
- Returns:
- A data loader for forcing data with coordinates and metadata.
- """
- available_times = XarrayDataset(config.dataset, requirements).all_times
- start_time_indices = []
- for time in initial_times.values:
- start_time_indices.append(available_times.get_loc(time))
- inference_config = config.build_inference_config(
- start_indices=ExplicitIndices(start_time_indices)
- )
- return get_inference_data(inference_config, forward_steps_in_memory, requirements)
diff --git a/fme/fme/core/data_loading/inference.py b/fme/fme/core/data_loading/inference.py
deleted file mode 100644
index 1670f0e..0000000
--- a/fme/fme/core/data_loading/inference.py
+++ /dev/null
@@ -1,240 +0,0 @@
-import dataclasses
-import datetime
-from math import ceil
-from typing import Sequence, Union
-
-import cftime
-import numpy as np
-import torch
-import xarray as xr
-
-from fme.core.data_loading._xarray import XarrayDataset
-from fme.core.data_loading.config import Slice, XarrayDataConfig
-from fme.core.data_loading.data_typing import HorizontalCoordinates, SigmaCoordinates
-from fme.core.data_loading.requirements import DataRequirements
-from fme.core.data_loading.utils import BatchData
-from fme.core.distributed import Distributed
-
-
-@dataclasses.dataclass
-class TimestampList:
- """
- Configuration for a list of timestamps.
-
- Attributes:
- times: List of timestamps.
- timestamp_format: Format of the timestamps.
- """
-
- times: Sequence[str]
- timestamp_format: str = "%Y-%m-%dT%H:%M:%S"
-
- def as_indices(self, time_index: xr.CFTimeIndex) -> np.ndarray:
- datetimes = [
- cftime.datetime.strptime(
- t, self.timestamp_format, calendar=time_index.calendar
- )
- for t in self.times
- ]
- (indices,) = time_index.isin(datetimes).nonzero()
- if len(indices) != len(self.times):
- missing_times = set(datetimes) - set(time_index[indices])
- raise ValueError(
- f"Inference initial condition timestamps {missing_times} "
- "were not found in the dataset."
- )
- return indices
-
- @property
- def n_initial_conditions(self) -> int:
- return len(self.times)
-
-
-@dataclasses.dataclass
-class InferenceInitialConditionIndices:
- """
- Configuration of the indices for initial conditions during inference.
-
- Attributes:
- n_initial_conditions: Number of initial conditions to use.
- first: Index of the first initial condition.
- interval: Interval between initial conditions.
- """
-
- n_initial_conditions: int
- first: int = 0
- interval: int = 1
-
- def __post_init__(self):
- if self.interval < 0:
- raise ValueError("interval must be positive")
-
- def as_indices(self) -> np.ndarray:
- stop = self.n_initial_conditions * self.interval + self.first
- return np.arange(self.first, stop, self.interval)
-
-
-@dataclasses.dataclass
-class ExplicitIndices:
- """
- Configure indices providing them explicitly.
-
- Attributes:
- list: List of integer indices.
- """
-
- list: Sequence[int]
-
- def as_indices(self) -> np.ndarray:
- return np.array(self.list)
-
- @property
- def n_initial_conditions(self) -> int:
- return len(self.list)
-
-
-@dataclasses.dataclass
-class InferenceDataLoaderConfig:
- """
- Configuration for inference data.
-
- This is like the `DataLoaderConfig` class, but with some additional
- constraints. During inference, we have only one batch, so the number of
- samples directly determines the size of that batch.
-
- Attributes:
- dataset: Configuration to define the dataset.
- start_indices: Configuration of the indices for initial conditions
- during inference. This can be a list of timestamps, a list of
- integer indices, or a slice configuration of the integer indices.
- Values following the initial condition will still come from
- the full dataset.
- num_data_workers: Number of parallel workers to use for data loading.
- """
-
- dataset: XarrayDataConfig
- start_indices: Union[
- InferenceInitialConditionIndices, ExplicitIndices, TimestampList
- ]
- num_data_workers: int = 0
-
- def __post_init__(self):
- if self.dataset.subset != Slice(None, None, None):
- raise ValueError("Inference data may not be subset.")
-
- @property
- def n_samples(self) -> int:
- return self.start_indices.n_initial_conditions
-
-
-@dataclasses.dataclass
-class ForcingDataLoaderConfig:
- """
- Configuration for the forcing data.
-
- Attributes:
- dataset: Configuration to define the dataset.
- num_data_workers: Number of parallel workers to use for data loading.
- """
-
- dataset: XarrayDataConfig
- num_data_workers: int = 0
-
- def __post_init__(self):
- if self.dataset.subset != Slice(None, None, None):
- raise ValueError("Inference data may not be subset.")
-
- def build_inference_config(self, start_indices: ExplicitIndices):
- return InferenceDataLoaderConfig(
- dataset=self.dataset,
- num_data_workers=self.num_data_workers,
- start_indices=start_indices,
- )
-
-
-class InferenceDataset(torch.utils.data.Dataset):
- def __init__(
- self,
- config: InferenceDataLoaderConfig,
- forward_steps_in_memory: int,
- requirements: DataRequirements,
- ):
- dataset = XarrayDataset(config.dataset, requirements=requirements)
- self._dataset = dataset
- self._sigma_coordinates = dataset.sigma_coordinates
- self._metadata = dataset.metadata
- self._area_weights = dataset.area_weights
- self._horizontal_coordinates = dataset.horizontal_coordinates
- self._timestep = dataset.timestep
- self._forward_steps_in_memory = forward_steps_in_memory
- self._total_steps = requirements.n_timesteps - 1
- self._is_remote = dataset.is_remote
- self.n_samples = config.n_samples # public attribute
- if isinstance(config.start_indices, TimestampList):
- self._start_indices = config.start_indices.as_indices(dataset.all_times)
- else:
- self._start_indices = config.start_indices.as_indices()
- self._validate_n_forward_steps()
-
- def __getitem__(self, index) -> BatchData:
- dist = Distributed.get_instance()
- i_start = index * self._forward_steps_in_memory
- sample_tuples = []
- for i_sample in range(self.n_samples):
- # check if sample is one this local rank should process
- if i_sample % dist.world_size != dist.rank:
- continue
- i_window_start = i_start + self._start_indices[i_sample]
- i_window_end = i_window_start + self._forward_steps_in_memory + 1
- if i_window_end > (self._total_steps + self._start_indices[i_sample]):
- i_window_end = self._total_steps + self._start_indices[i_sample] + 1
- window_time_slice = slice(i_window_start, i_window_end)
- sample_tuples.append(
- self._dataset.get_sample_by_time_slice(window_time_slice)
- )
- result = BatchData.from_sample_tuples(sample_tuples)
- assert result.times.shape[0] == self.n_samples // dist.world_size
- return result
-
- def __len__(self) -> int:
- # The ceil is necessary so if the last batch is smaller
- # than the rest the ratio will be rounded up and the last batch
- # will be included in the loading
- return int(ceil(self._total_steps / self._forward_steps_in_memory))
-
- @property
- def sigma_coordinates(self) -> SigmaCoordinates:
- return self._sigma_coordinates
-
- @property
- def metadata(self) -> xr.Dataset:
- return self._metadata
-
- @property
- def area_weights(self) -> xr.DataArray:
- return self._area_weights
-
- @property
- def horizontal_coordinates(self) -> HorizontalCoordinates:
- return self._horizontal_coordinates
-
- @property
- def timestep(self) -> datetime.timedelta:
- return self._timestep
-
- @property
- def is_remote(self) -> bool:
- return self._is_remote
-
- @property
- def n_forward_steps(self) -> int:
- return self._total_steps
-
- def _validate_n_forward_steps(self):
- max_steps = self._dataset.total_timesteps - self._start_indices[-1] - 1
- if self._total_steps > max_steps:
- raise ValueError(
- f"The number of forward inference steps ({self._total_steps}) must "
- f"be less than or equal to the number of possible steps ({max_steps})"
- f"in dataset after the last initial condition's start index."
- )
diff --git a/fme/fme/core/data_loading/requirements.py b/fme/fme/core/data_loading/requirements.py
deleted file mode 100644
index 7aecf19..0000000
--- a/fme/fme/core/data_loading/requirements.py
+++ /dev/null
@@ -1,8 +0,0 @@
-import dataclasses
-from typing import List
-
-
-@dataclasses.dataclass
-class DataRequirements:
- names: List[str]
- n_timesteps: int
diff --git a/fme/fme/core/dataset/__init__.py b/fme/fme/core/dataset/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/fme/fme/core/dataset/config.py b/fme/fme/core/dataset/config.py
new file mode 100644
index 0000000..a7ecbe6
--- /dev/null
+++ b/fme/fme/core/dataset/config.py
@@ -0,0 +1,280 @@
+import dataclasses
+from datetime import timedelta
+from typing import Literal, Mapping, Optional, Union
+
+import numpy as np
+import pandas as pd
+import torch
+import xarray as xr
+
+from fme.core.typing_ import Slice, TensorDict
+
+
+@dataclasses.dataclass
+class TimeSlice:
+ """
+ Configuration of a slice of times. Step is an integer-valued index step.
+
+    Note: start_time and stop_time may be provided as partial time strings;
+    stop_time is included in the slice. See the `Xarray docs`_ for more details.
+
+ Parameters:
+ start_time: Start time of the slice.
+ stop_time: Stop time of the slice.
+ step: Step of the slice.
+
+ .. _Xarray docs:
+ https://docs.xarray.dev/en/latest/user-guide/weather-climate.html#non-standard-calendars-and-dates-outside-the-nanosecond-precision-range
+ """ # noqa: E501
+
+ start_time: Optional[str] = None
+ stop_time: Optional[str] = None
+ step: Optional[int] = None
+
+ def slice(self, time: xr.CFTimeIndex) -> slice:
+ return time.slice_indexer(self.start_time, self.stop_time, self.step)
+
+
+def _convert_interval_to_int(
+ interval: pd.Timedelta,
+ timestep: timedelta,
+):
+ """Convert interval to integer number of timesteps."""
+ if interval % timestep != timedelta(0):
+ raise ValueError(
+ f"Requested interval length {interval} is not a "
+ f"multiple of the timestep {timestep}."
+ )
+
+ return interval // timestep
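A quick illustration of the conversion performed above (a minimal sketch; the 3-hour timestep mirrors the mock data used in the tests):

# Illustrative only: a 6-hour interval with a 3-hour timestep spans two steps,
# while a 4-hour interval would be rejected because it is not a multiple of the step.
from datetime import timedelta

import pandas as pd

assert pd.Timedelta("6h") // timedelta(hours=3) == 2
assert pd.Timedelta("4h") % timedelta(hours=3) != timedelta(0)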
+
+
+@dataclasses.dataclass
+class RepeatedInterval:
+ """
+ Configuration for a repeated interval within a block. This configuration
+ is used to generate a boolean mask for a dataset that will return values
+ within the interval and repeat that throughout the dataset.
+
+ Parameters:
+        interval_length: Length of the interval to return values from.
+        start: Start position of the interval within the repeat block.
+        block_length: Total length of the block to be repeated over the length of
+            the dataset, including the interval length.
+
+    Note:
+        interval_length, start, and block_length must either all be integers
+        (numbers of timesteps) or all be strings representing timedeltas
+        (e.g. "1d"). If provided as strings, the timestep must be provided
+        when calling `get_boolean_mask`.
+
+ Examples:
+        To return values from the first 3 items of every 6 items, use:
+
+        >>> RepeatedInterval(interval_length=3, block_length=6, start=0)
+
+        To return a day's worth of values starting 2 days into every 7-day
+        block, use:
+
+        >>> RepeatedInterval(interval_length="1d", block_length="7d", start="2d")
+ """
+
+ interval_length: Union[int, str]
+ start: Union[int, str]
+ block_length: Union[int, str]
+
+ def __post_init__(self):
+ types = {type(self.interval_length), type(self.block_length), type(self.start)}
+ if len(types) > 1:
+ raise ValueError(
+ "All attributes of RepeatedInterval must be of the "
+ "same type (either all int or all str)."
+ )
+
+ self._is_time_delta_str = isinstance(self.interval_length, str)
+
+ if self._is_time_delta_str:
+ self.interval_length = pd.Timedelta(self.interval_length)
+ self.block_length = pd.Timedelta(self.block_length)
+ self.start = pd.Timedelta(self.start)
+
+ def get_boolean_mask(
+ self, length: int, timestep: Optional[timedelta] = None
+ ) -> np.ndarray:
+ """
+ Return a boolean mask for the repeated interval.
+
+ Args:
+ length: Length of the dataset.
+ timestep: Timestep of the dataset.
+ """
+ if self._is_time_delta_str:
+ if timestep is None:
+ raise ValueError(
+ "Timestep must be provided when using time deltas "
+ "for RepeatedInterval."
+ )
+
+ interval_length = _convert_interval_to_int(self.interval_length, timestep)
+ block_length = _convert_interval_to_int(self.block_length, timestep)
+ start = _convert_interval_to_int(self.start, timestep)
+ else:
+ interval_length = self.interval_length
+ block_length = self.block_length
+ start = self.start
+
+ if start + interval_length > block_length:
+ raise ValueError(
+ "The interval (with start point) must fit within the repeat block."
+ )
+
+ block = np.zeros(block_length, dtype=bool)
+ block[start : start + interval_length] = True
+ num_blocks = length // block_length + 1
+ mask = np.tile(block, num_blocks)[:length]
+ return mask
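A minimal sketch of the masking logic above, using the integer form so no timestep is needed; the mask in the comment follows from the block construction in get_boolean_mask:

# Select the first 3 of every 6 samples from a dataset of length 10; the
# trailing partial block is truncated to the dataset length.
from fme.core.dataset.config import RepeatedInterval

interval = RepeatedInterval(interval_length=3, start=0, block_length=6)
mask = interval.get_boolean_mask(length=10)
# mask: [T, T, T, F, F, F, T, T, T, T]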
+
+
+@dataclasses.dataclass
+class OverwriteConfig:
+ """Configuration to overwrite field values in XarrayDataset.
+
+ Parameters:
+ constant: Fill field with constant value.
+ multiply_scalar: Multiply field by scalar value.
+ """
+
+ constant: Mapping[str, float] = dataclasses.field(default_factory=dict)
+ multiply_scalar: Mapping[str, float] = dataclasses.field(default_factory=dict)
+
+ def __post_init__(self):
+ key_overlap = set(self.constant.keys()) & set(self.multiply_scalar.keys())
+ if key_overlap:
+ raise ValueError(
+ "OverwriteConfig cannot have the same variable in both constant "
+ f"and multiply_scalar: {key_overlap}"
+ )
+
+ def apply(self, tensors: TensorDict) -> TensorDict:
+ for var, fill_value in self.constant.items():
+ data = tensors[var]
+ tensors[var] = torch.ones_like(data) * torch.tensor(
+ fill_value, dtype=data.dtype, device=data.device
+ )
+ for var, multiplier in self.multiply_scalar.items():
+ data = tensors[var]
+ tensors[var] = data * torch.tensor(
+ multiplier, dtype=data.dtype, device=data.device
+ )
+ return tensors
+
+ @property
+ def variables(self):
+ return set(self.constant.keys()) | set(self.multiply_scalar.keys())
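A short sketch of how OverwriteConfig.apply transforms a tensor dict; the variable names and shapes here are arbitrary choices for illustration:

import torch

from fme.core.dataset.config import OverwriteConfig

overwrite = OverwriteConfig(constant={"foo": -10.0}, multiply_scalar={"bar": 3.5})
tensors = {"foo": torch.randn(2, 4, 8), "bar": torch.randn(2, 4, 8)}
result = overwrite.apply(tensors)
assert torch.all(result["foo"] == -10.0)  # "foo" filled with the constant
# "bar" has been multiplied by 3.5, preserving dtype and device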
+
+
+@dataclasses.dataclass
+class FillNaNsConfig:
+ """
+    Configuration for filling NaNs. Currently, filling with a constant value is
+    the only supported method.
+
+ Parameters:
+ method: Type of fill operation. Currently only 'constant' is supported.
+ value: Value to fill NaNs with.
+ """
+
+ method: Literal["constant"] = "constant"
+ value: float = 0.0
+
+
+@dataclasses.dataclass
+class XarrayDataConfig:
+ """
+ Parameters:
+ data_path: Path to the data.
+ file_pattern: Glob pattern to match files in the data_path.
+        n_repeats: Number of times to repeat the dataset (in time). It is up
+            to the user to ensure that the repeated data is reasonably
+            continuous in time across repetitions.
+ engine: Backend used in xarray.open_dataset call.
+ spatial_dimensions: Specifies the spatial dimensions for the grid, default
+ is lat/lon.
+        subset: Slice defining a subset of the XarrayDataset to load. This can
+            be a `Slice` of integer indices, a `TimeSlice` of timestamps, or a
+            `RepeatedInterval` repeated over the dataset.
+ infer_timestep: Whether to infer the timestep from the provided data.
+ This should be set to True (the default) for ACE training. It may
+ be useful to toggle this to False for applications like downscaling,
+ which do not depend on the timestep of the data and therefore lack
+ the additional requirement that the data be ordered and evenly
+ spaced in time. It must be set to True if n_repeats > 1 in order
+ to be able to infer the full time coordinate.
+ dtype: Data type to cast the data to. If None, no casting is done. It is
+ required that 'torch.{dtype}' is a valid dtype.
+ overwrite: Optional OverwriteConfig to overwrite loaded field values. If this is
+ configured for a renamed field, the key should be the final updated name.
+        renamed_variables: Optional mapping of {old_name: new_name} used to
+            rename variables in the dataset.
+ fill_nans: Optional FillNaNsConfig to fill NaNs with a constant value.
+
+ Examples:
+ If data is stored in a directory with multiple netCDF files which can be
+ concatenated along the time dimension, use:
+
+ >>> fme.ace.XarrayDataConfig(data_path="/some/directory", file_pattern="*.nc") # doctest: +IGNORE_OUTPUT
+
+ If data is stored in a single zarr store at ``/some/directory/dataset.zarr``,
+ use:
+
+ >>> fme.ace.XarrayDataConfig(
+ ... data_path="/some/directory",
+ ... file_pattern="dataset.zarr",
+ ... engine="zarr"
+ ... ) # doctest: +IGNORE_OUTPUT
+ """ # noqa: E501
+
+ data_path: str
+ file_pattern: str = "*.nc"
+ n_repeats: int = 1
+ engine: Literal["netcdf4", "h5netcdf", "zarr"] = "netcdf4"
+ spatial_dimensions: Literal["healpix", "latlon"] = "latlon"
+ subset: Union[Slice, TimeSlice, RepeatedInterval] = dataclasses.field(
+ default_factory=Slice
+ )
+ infer_timestep: bool = True
+ dtype: Optional[str] = "float32"
+ overwrite: OverwriteConfig = dataclasses.field(default_factory=OverwriteConfig)
+ renamed_variables: Optional[Mapping[str, str]] = None
+ fill_nans: Optional[FillNaNsConfig] = None
+
+ def _default_file_pattern_check(self):
+ if self.engine == "zarr" and self.file_pattern == "*.nc":
+ raise ValueError(
+ "The file pattern is set to the default NetCDF file pattern *.nc "
+ "but the engine is specified as 'zarr'. Please set "
+ "`XarrayDataConfig.file_pattern` to match the zarr filename."
+ )
+
+ def __post_init__(self):
+ if self.n_repeats > 1 and not self.infer_timestep:
+ raise ValueError(
+ "infer_timestep must be True if n_repeats is greater than 1"
+ )
+ if self.dtype is None:
+ self.torch_dtype = None
+ else:
+ try:
+ self.torch_dtype = getattr(torch, self.dtype)
+ except AttributeError:
+ raise ValueError(f"Invalid dtype '{self.dtype}'")
+ if not isinstance(self.torch_dtype, torch.dtype):
+ raise ValueError(f"Invalid dtype '{self.dtype}'")
+
+ # Raise error if overwrite variables are in the keys of renamed variables
+ if self.renamed_variables is not None:
+ overlap = set(self.overwrite.variables) & set(self.renamed_variables.keys())
+ if overlap:
+ raise ValueError(
+ "Variables in overwrite should not be the original names before "
+ f"renaming: {overlap}. "
+ "Please use the final renamed variables in the overwrite config."
+ )
+ self._default_file_pattern_check()
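Putting the pieces of this config module together, a sketch of a typical configuration; the path is hypothetical, and subset could equally be a Slice of integer indices or a RepeatedInterval:

from fme.core.dataset.config import TimeSlice, XarrayDataConfig

config = XarrayDataConfig(
    data_path="/some/directory",
    file_pattern="*.nc",
    subset=TimeSlice(start_time="2003-03", stop_time="2003-05"),
    dtype="float32",
)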
diff --git a/fme/fme/core/dataset/data_typing.py b/fme/fme/core/dataset/data_typing.py
new file mode 100644
index 0000000..3cd96b4
--- /dev/null
+++ b/fme/fme/core/dataset/data_typing.py
@@ -0,0 +1,29 @@
+import abc
+from collections import namedtuple
+from typing import Tuple
+
+import torch
+import xarray as xr
+
+from fme.core.typing_ import TensorDict
+
+VariableMetadata = namedtuple("VariableMetadata", ["units", "long_name"])
+
+
+class Dataset(torch.utils.data.Dataset, abc.ABC):
+ @abc.abstractmethod
+ def get_sample_by_time_slice(
+ self, time_slice: slice
+ ) -> Tuple[TensorDict, xr.DataArray]:
+ """
+ Returns a sample of data for the given time slice.
+
+ Args:
+ time_slice: The time slice to return data for.
+
+ Returns:
+ A tuple whose first item is a mapping from variable
+ name to tensor of shape [n_time, n_lat, n_lon] and
+ whose second item is a time coordinate array.
+ """
+ ...
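A minimal sketch of a concrete subclass satisfying the abstract interface above; the in-memory storage and attribute names are illustrative assumptions, not part of the library:

import xarray as xr

from fme.core.dataset.data_typing import Dataset


class InMemoryDataset(Dataset):
    """Serves samples from tensors held in memory (illustrative only)."""

    def __init__(self, data, time: xr.DataArray, n_timesteps: int):
        self._data = data  # mapping from name to [n_time, n_lat, n_lon] tensors
        self._time = time  # time coordinate with a "time" dimension
        self._n_timesteps = n_timesteps  # timesteps per sample window

    def __len__(self):
        return self._time.sizes["time"] - self._n_timesteps + 1

    def __getitem__(self, index):
        return self.get_sample_by_time_slice(slice(index, index + self._n_timesteps))

    def get_sample_by_time_slice(self, time_slice):
        tensors = {name: value[time_slice] for name, value in self._data.items()}
        return tensors, self._time[time_slice]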
diff --git a/fme/fme/core/dataset/getters.py b/fme/fme/core/dataset/getters.py
new file mode 100644
index 0000000..e85e9cc
--- /dev/null
+++ b/fme/fme/core/dataset/getters.py
@@ -0,0 +1,46 @@
+import warnings
+from typing import List, Optional, Sequence, Tuple
+
+import torch.utils.data
+
+from fme.core.dataset.config import XarrayDataConfig
+from fme.core.dataset.xarray import DatasetProperties, XarrayDataset, get_xarray_dataset
+
+from .requirements import DataRequirements
+
+
+def get_datasets(
+ dataset_configs: Sequence[XarrayDataConfig],
+ requirements: DataRequirements,
+ strict: bool = True,
+) -> Tuple[List[XarrayDataset], DatasetProperties]:
+ datasets = []
+ properties: Optional[DatasetProperties] = None
+ for config in dataset_configs:
+ dataset, new_properties = get_xarray_dataset(config, requirements)
+ datasets.append(dataset)
+ if properties is None:
+ properties = new_properties
+ elif not strict:
+ try:
+ properties.update(new_properties)
+ except ValueError as e:
+ warnings.warn(
+ f"Metadata for each ensemble member are not the same: {e}"
+ )
+ else:
+ properties.update(new_properties)
+ if properties is None:
+ raise ValueError("At least one dataset must be provided.")
+
+ return datasets, properties
+
+
+def get_dataset(
+ dataset_configs: Sequence[XarrayDataConfig],
+ requirements: DataRequirements,
+ strict: bool = True,
+) -> Tuple[torch.utils.data.ConcatDataset[XarrayDataset], DatasetProperties]:
+ datasets, properties = get_datasets(dataset_configs, requirements, strict=strict)
+ ensemble = torch.utils.data.ConcatDataset(datasets)
+ return ensemble, properties
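A sketch of how these getters might be called for a small ensemble; the data paths and variable names are hypothetical. With the default strict=True, any mismatch in variable metadata, timestep, or coordinates between members raises a ValueError instead of only warning:

from fme.core.dataset.config import XarrayDataConfig
from fme.core.dataset.getters import get_dataset
from fme.core.dataset.requirements import DataRequirements

configs = [
    XarrayDataConfig(data_path="/data/ic_0001"),
    XarrayDataConfig(data_path="/data/ic_0002"),
]
requirements = DataRequirements(names=["foo", "bar"], n_timesteps=2)
ensemble, properties = get_dataset(configs, requirements)
# ensemble is a ConcatDataset over the members; properties holds the shared
# variable metadata, coordinates, and timestep validated across members.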
diff --git a/fme/fme/core/dataset/requirements.py b/fme/fme/core/dataset/requirements.py
new file mode 100644
index 0000000..65e2ab6
--- /dev/null
+++ b/fme/fme/core/dataset/requirements.py
@@ -0,0 +1,16 @@
+import dataclasses
+from typing import List
+
+
+@dataclasses.dataclass
+class DataRequirements:
+ """
+ The requirements for batches (time windows) of loaded data.
+
+ Parameters:
+ names: Names of the variables to load.
+ n_timesteps: Number of timesteps to load in each batch window.
+ """
+
+ names: List[str]
+ n_timesteps: int
diff --git a/fme/fme/core/data_loading/test_utils.py b/fme/fme/core/dataset/test_utils.py
similarity index 75%
rename from fme/fme/core/data_loading/test_utils.py
rename to fme/fme/core/dataset/test_utils.py
index e82f2f6..2091bc7 100644
--- a/fme/fme/core/data_loading/test_utils.py
+++ b/fme/fme/core/dataset/test_utils.py
@@ -1,17 +1,20 @@
import copy
import datetime
+from typing import List
import numpy as np
import pytest
import torch
import xarray as xr
-from fme.core.data_loading.data_typing import (
+from fme.core.coordinates import (
+ DimSize,
HEALPixCoordinates,
HorizontalCoordinates,
LatLonCoordinates,
)
-from fme.core.data_loading.utils import (
+
+from .utils import (
_get_indexers,
as_broadcasted_tensor,
decode_timestep,
@@ -31,12 +34,12 @@ def get_sizes(
spatial_dims: HorizontalCoordinates = LatLonCoordinates(
lon=torch.Tensor(np.arange(6)),
lat=torch.Tensor(np.arange(12)),
- lat_name=LAT_DIM,
- lon_name=LON_DIM,
- )
+ loaded_lat_name=LAT_DIM,
+ loaded_lon_name=LON_DIM,
+ ),
):
- spatial_sizes: dict = copy.deepcopy(spatial_dims.default_sizes)
- spatial_sizes[TIME_DIM] = 3
+ spatial_sizes: List[DimSize] = copy.deepcopy(spatial_dims.loaded_default_sizes)
+ spatial_sizes.append(DimSize(TIME_DIM, 3))
return spatial_sizes
@@ -44,13 +47,13 @@ def create_reference_dataset(
spatial_dims: HorizontalCoordinates = LatLonCoordinates(
lon=torch.Tensor(np.arange(6)),
lat=torch.Tensor(np.arange(12)),
- lat_name=LAT_DIM,
- lon_name=LON_DIM,
- )
+ loaded_lat_name=LAT_DIM,
+ loaded_lon_name=LON_DIM,
+ ),
):
- dims = [TIME_DIM] + spatial_dims.dims
- sizes = get_sizes(spatial_dims=spatial_dims)
- shape = tuple(sizes[dim] for dim in dims)
+ dims = [TIME_DIM] + spatial_dims.loaded_dims
+ dim_sizes = get_sizes(spatial_dims=spatial_dims)
+ shape = tuple(dim_size.size for dim_size in dim_sizes)
data = np.arange(np.prod(shape)).reshape(shape)
coords = [np.arange(size) for size in shape]
full = xr.DataArray(data, dims=dims, coords=coords, name=FULL_NAME)
@@ -71,11 +74,13 @@ def test_infer_horizontal_dimension_names(lon_dim, lat_dim, warns):
spatial_dims = LatLonCoordinates(
lon=torch.Tensor(np.arange(6)),
lat=torch.Tensor(np.arange(12)),
- lat_name=lat_dim,
- lon_name=lon_dim,
+ loaded_lat_name=lat_dim,
+ loaded_lon_name=lon_dim,
)
ds = create_reference_dataset(spatial_dims=spatial_dims)
expected = [lon_dim, lat_dim]
+ for dim in expected:
+ assert dim in ds.dims
if warns:
with pytest.warns(UserWarning, match="Familiar"):
infer_horizontal_dimension_names(ds)
@@ -91,18 +96,32 @@ def test_infer_horizontal_dimension_names_healpix():
width=torch.Tensor(np.arange(64)),
height=torch.Tensor(np.arange(64)),
)
- ds = create_reference_dataset(hpx_coords)
- expected = hpx_coords.dims
+ ds = create_reference_dataset(spatial_dims=hpx_coords)
+ expected = hpx_coords.loaded_dims
result = infer_horizontal_dimension_names(ds)
assert result == expected
+@pytest.mark.parametrize(
+ ["coordinate_type", "coord_sizes"],
+ [
+ pytest.param(LatLonCoordinates, {"lat": 90, "lon": 180}),
+ pytest.param(HEALPixCoordinates, {"face": 12, "height": 64, "width": 64}),
+ ],
+)
+def test_horizontal_dimension_sizes(coordinate_type, coord_sizes):
+ coords = {name: torch.Tensor(np.arange(size)) for name, size in coord_sizes.items()}
+ horizontal_coords = coordinate_type(**coords)
+ for name, size in coord_sizes.items():
+ assert len(horizontal_coords.coords[name]) == size
+
+
def test_infer_horizontal_dimension_names_error():
spatial_dims = LatLonCoordinates(
lon=torch.Tensor(np.arange(6)),
lat=torch.Tensor(np.arange(12)),
- lat_name="foo",
- lon_name="bar",
+ loaded_lat_name="foo",
+ loaded_lon_name="bar",
)
ds = create_reference_dataset(spatial_dims=spatial_dims)
ds = ds.isel(time=0)
@@ -123,8 +142,10 @@ def test_infer_horizontal_dimension_names_error():
ids=lambda x: f"{x}",
)
def test__get_indexers(variable_dims, expected):
- sizes = get_sizes()
- shape = tuple(sizes[dim] for dim in variable_dims)
+ dim_sizes = get_sizes()
+ shape = tuple(
+ dim_size.size for dim_size in dim_sizes if dim_size.name in variable_dims
+ )
variable = xr.Variable(variable_dims, np.zeros(shape))
dims = (TIME_DIM, LAT_DIM, LON_DIM)
result = _get_indexers(variable, dims)
diff --git a/fme/fme/core/data_loading/test__xarray.py b/fme/fme/core/dataset/test_xarray.py
similarity index 72%
rename from fme/fme/core/data_loading/test__xarray.py
rename to fme/fme/core/dataset/test_xarray.py
index 014c5bb..14ba0fa 100755
--- a/fme/fme/core/data_loading/test__xarray.py
+++ b/fme/fme/core/dataset/test_xarray.py
@@ -12,29 +12,30 @@
import torch
import xarray as xr
-from fme.core.data_loading._xarray import (
+from fme.core.coordinates import LatLonCoordinates
+from fme.core.dataset.config import (
+ FillNaNsConfig,
+ OverwriteConfig,
+ TimeSlice,
+ XarrayDataConfig,
+)
+from fme.core.dataset.requirements import DataRequirements
+from fme.core.dataset.xarray import (
XarrayDataset,
get_cumulative_timesteps,
get_file_local_index,
get_raw_times,
get_timestep,
- repeat_and_increment_times,
-)
-from fme.core.data_loading.config import (
- DataLoaderConfig,
- Slice,
- TimeSlice,
- XarrayDataConfig,
-)
-from fme.core.data_loading.getters import get_data_loader, get_dataset
-from fme.core.data_loading.requirements import DataRequirements
-from fme.core.data_loading.utils import (
- as_broadcasted_tensor,
- infer_horizontal_dimension_names,
+ get_xarray_dataset,
+ repeat_and_increment_time,
)
+from fme.core.typing_ import Slice
+
+from .utils import as_broadcasted_tensor, infer_horizontal_dimension_names
SLICE_NONE = slice(None)
MOCK_DATA_FREQ = "3h"
+MOCK_DATA_START_DATE = "2003-03"
@dataclasses.dataclass
@@ -75,6 +76,7 @@ def _get_data(
file_freq,
step_freq,
calendar,
+ with_nans=False,
) -> MockData:
"""Constructs an xarray dataset and saves to disk in netcdf format."""
obs_times = xr.cftime_range(
@@ -109,15 +111,17 @@ def _get_data(
last = start_times[i + 1]
else:
last = obs_times[-1] + obs_delta
- times = xr.cftime_range(
+ time = xr.cftime_range(
first, last, freq=step_freq, calendar=calendar, inclusive="left"
)
data_vars: Dict[str, Union[float, xr.DataArray]] = {**ak, **bk}
for var_name in var_names:
- data = np.random.randn(len(times), n_lat, n_lon).astype(np.float32)
+ data = np.random.randn(len(time), n_lat, n_lon).astype(np.float32)
+ if with_nans:
+ data[0, :, 0] = np.nan
data_vars[var_name] = xr.DataArray(data, dims=("time", "lat", "lon"))
- data_varying_scalar = np.random.randn(len(times)).astype(np.float32)
+ data_varying_scalar = np.random.randn(len(time)).astype(np.float32)
data_vars["varying_scalar_var"] = xr.DataArray(
data_varying_scalar, dims=("time",)
)
@@ -126,7 +130,7 @@ def _get_data(
data_vars["constant_scalar_var"] = constant_scalar_var
coords = {
- "time": xr.DataArray(times, dims=("time",)),
+ "time": xr.DataArray(time, dims=("time",)),
"lat": xr.DataArray(np.arange(n_lat, dtype=np.float32), dims=("lat",)),
"lon": xr.DataArray(np.arange(n_lon, dtype=np.float32), dims=("lon",)),
}
@@ -151,15 +155,16 @@ def _get_data(
return MockData(tmpdir, obs_times, start_times, start_indices, variable_names)
-def get_mock_monthly_netcdfs(tmp_path_factory, dirname) -> MockData:
+def get_mock_monthly_netcdfs(tmp_path_factory, dirname, with_nans=False) -> MockData:
return _get_data(
tmp_path_factory,
dirname,
- start="2003-03",
+ start=MOCK_DATA_START_DATE,
end="2003-06",
file_freq="MS",
step_freq=MOCK_DATA_FREQ,
calendar="standard",
+ with_nans=with_nans,
)
@@ -168,6 +173,11 @@ def mock_monthly_netcdfs(tmp_path_factory) -> MockData:
return get_mock_monthly_netcdfs(tmp_path_factory, "month")
+@pytest.fixture(scope="session")
+def mock_monthly_netcdfs_with_nans(tmp_path_factory) -> MockData:
+ return get_mock_monthly_netcdfs(tmp_path_factory, "month_with_nans", with_nans=True)
+
+
@pytest.fixture(scope="session")
def mock_monthly_zarr(tmp_path_factory, mock_monthly_netcdfs) -> MockData:
zarr_parent = tmp_path_factory.mktemp("zarr")
@@ -250,7 +260,7 @@ def _test_monthly_values(
expected_n_samples = len(mock_data.obs_times) - 1
assert len(dataset) == expected_n_samples
- arrays, times = dataset[global_idx]
+ arrays, time = dataset[global_idx]
with xr.open_mfdataset(
mock_data.tmpdir.glob(file_pattern),
engine=engine,
@@ -259,7 +269,7 @@ def _test_monthly_values(
coords="minimal",
) as ds:
target_times = ds["time"][global_idx : global_idx + 2].drop_vars("time")
- xr.testing.assert_equal(times, target_times)
+ xr.testing.assert_equal(time, target_times)
lon_dim, lat_dim = infer_horizontal_dimension_names(ds)
dims = ("time", str(lat_dim), str(lon_dim))
shape = (2, ds.sizes[lat_dim], ds.sizes[lon_dim])
@@ -309,17 +319,13 @@ def test_XarrayDataset_monthly_n_timesteps(mock_monthly_netcdfs, n_samples):
mock_data: MockData = mock_monthly_netcdfs
if len(mock_data.var_names.initial_condition_names) != 0:
return
- config = DataLoaderConfig(
- [XarrayDataConfig(data_path=mock_data.tmpdir, subset=Slice(stop=n_samples))],
- 1,
- 0,
- )
+ config = XarrayDataConfig(data_path=mock_data.tmpdir, subset=Slice(stop=n_samples))
n_forward_steps = 4
requirements = DataRequirements(
names=mock_data.var_names.all_names + ["x"],
n_timesteps=n_forward_steps + 1,
)
- dataset = get_dataset(config.dataset, requirements)
+ dataset, _ = get_xarray_dataset(config, requirements)
if n_samples is None:
assert len(dataset) == len(mock_data.obs_times) - n_forward_steps
else:
@@ -387,25 +393,36 @@ def test_XarrayDataset_yearly(mock_yearly_netcdfs, global_idx):
target_times = ds["time"][global_idx : global_idx + n_steps].drop_vars(
"time"
)
- data, times = dataset[global_idx]
- data = data[var_name]
- assert data.shape[0] == n_steps
- assert torch.equal(data, target_data)
- xr.testing.assert_equal(times, target_times)
+ data, time = dataset[global_idx]
+ data_tensor = data[var_name]
+ assert data_tensor.shape[0] == n_steps
+ assert torch.equal(data_tensor, target_data)
+ xr.testing.assert_equal(time, target_times)
+
+
+def test_dataset_dtype_casting(mock_monthly_netcdfs):
+ mock_data: MockData = mock_monthly_netcdfs
+ config = XarrayDataConfig(data_path=mock_data.tmpdir, dtype="bfloat16")
+ requirements = DataRequirements(names=mock_data.var_names.all_names, n_timesteps=2)
+ dataset = XarrayDataset(config=config, requirements=requirements)
+ assert isinstance(dataset.horizontal_coordinates, LatLonCoordinates)
+ assert dataset.horizontal_coordinates.lat.dtype == torch.bfloat16
+ assert dataset.horizontal_coordinates.lon.dtype == torch.bfloat16
+ assert dataset.vertical_coordinate.ak.dtype == torch.bfloat16
+ assert dataset.vertical_coordinate.bk.dtype == torch.bfloat16
+ data, _ = dataset[0]
+ for tensor in data.values():
+ assert tensor.dtype == torch.bfloat16
def test_time_invariant_variable_is_repeated(mock_monthly_netcdfs):
mock_data: MockData = mock_monthly_netcdfs
- config = DataLoaderConfig(
- [XarrayDataConfig(data_path=mock_data.tmpdir)],
- batch_size=1,
- num_data_workers=0,
- )
+ config = XarrayDataConfig(data_path=mock_data.tmpdir)
requirements = DataRequirements(names=mock_data.var_names.all_names, n_timesteps=15)
- data = get_data_loader(config=config, train=False, requirements=requirements)
- batch, _ = data.loader.dataset[0]
- assert batch["constant_var"].shape[0] == 15
- assert batch["constant_scalar_var"].shape == (15, 4, 8)
+ dataset = XarrayDataset(config=config, requirements=requirements)
+ data = dataset[0][0]
+ assert data["constant_var"].shape[0] == 15
+ assert data["constant_scalar_var"].shape == (15, 4, 8)
def _get_repeat_dataset(
@@ -479,7 +496,7 @@ def test_time_index(mock_monthly_netcdfs):
)
last_sample_init_time = len(mock_monthly_netcdfs.obs_times) - n_timesteps + 1
obs_times = mock_monthly_netcdfs.obs_times[:last_sample_init_time]
- assert dataset.sample_start_times.equals(xr.CFTimeIndex(obs_times))
+ assert dataset.sample_start_time.equals(xr.CFTimeIndex(obs_times))
@pytest.mark.parametrize("infer_timestep", [True, False])
@@ -559,7 +576,7 @@ def test_repeat_and_increment_times(n_repeats):
raw_periods = [periods_a, periods_b]
raw_total_periods = sum(raw_periods)
- result = repeat_and_increment_times(raw_times, n_repeats, delta)
+ result = repeat_and_increment_time(raw_times, n_repeats, delta)
full_periods = [len(times) for times in result]
full_total_periods = sum(full_periods)
@@ -573,12 +590,130 @@ def test_repeat_and_increment_times(n_repeats):
np.testing.assert_equal(result_concatenated, expected_concatenated)
-def test_available_times(mock_monthly_netcdfs):
+@pytest.mark.parametrize("n_repeats", [1, 3])
+def test_all_times(mock_monthly_netcdfs, n_repeats):
+ n_timesteps = 2 # Arbitrary for this test
+ dataset = _get_repeat_dataset(mock_monthly_netcdfs, n_timesteps, n_repeats)
+ expected_periods = n_repeats * len(mock_monthly_netcdfs.obs_times)
+ expected = xr.cftime_range(
+ MOCK_DATA_START_DATE, periods=expected_periods, freq=MOCK_DATA_FREQ
+ )
+ result = dataset.all_times
+ assert result.equals(expected)
+
+
+def test_get_sample_by_time_slice_times_n_repeats(mock_monthly_netcdfs: MockData):
+ n_timesteps = 2 # Arbitrary for this test
+ n_repeats = 3
+ repeated_dataset = _get_repeat_dataset(mock_monthly_netcdfs, n_timesteps, n_repeats)
+
+ # Pick a slice that is outside the range of the unrepeated data
+ unrepeated_length = len(repeated_dataset.all_times) // n_repeats
+ time_slice = slice(unrepeated_length, unrepeated_length + 3)
+
+ _, result = repeated_dataset.get_sample_by_time_slice(time_slice)
+ expected = xr.DataArray(
+ repeated_dataset.all_times[time_slice].values, dims=["time"]
+ )
+ xr.testing.assert_equal(result, expected)
+
+
+@pytest.mark.parametrize(
+ "dtype,expected_torch_dtype", [("int16", torch.int16), (None, None)]
+)
+def test_dataset_config_dtype(dtype, expected_torch_dtype):
+ config = XarrayDataConfig(data_path="path/to/data", dtype=dtype)
+ assert config.torch_dtype == expected_torch_dtype
+
+
+def test_dataset_config_dtype_raises():
+ with pytest.raises(ValueError):
+ XarrayDataConfig(data_path="path/to/data", dtype="invalid_dtype")
+
+
+def test_renaming(mock_monthly_netcdfs):
+ # stepper in/out names should be variables after renaming
+ stepper_variables = ["foo", "bar_new"]
+ config = XarrayDataConfig(
+ data_path=mock_monthly_netcdfs.tmpdir, renamed_variables={"bar": "bar_new"}
+ )
+ dataset = XarrayDataset(
+ config,
+ DataRequirements(names=stepper_variables, n_timesteps=2),
+ )
+ data, _ = dataset[0]
+ assert "bar_new" in data
+ assert "bar" not in data
+
+
+def test_fill_nans(mock_monthly_netcdfs_with_nans):
+ nan_config = FillNaNsConfig()
+ config = XarrayDataConfig(
+ data_path=mock_monthly_netcdfs_with_nans.tmpdir, fill_nans=nan_config
+ )
+ requirements = DataRequirements(
+ names=mock_monthly_netcdfs_with_nans.var_names.all_names, n_timesteps=2
+ )
+ dataset = XarrayDataset(config, requirements)
+ data, _ = dataset[0]
+ assert torch.all(data["foo"][0, :, 0] == 0)
+
+
+def test_keep_nans(mock_monthly_netcdfs_with_nans):
+ config_keep_nan = XarrayDataConfig(data_path=mock_monthly_netcdfs_with_nans.tmpdir)
+ requirements = DataRequirements(
+ names=mock_monthly_netcdfs_with_nans.var_names.all_names, n_timesteps=2
+ )
+ dataset = XarrayDataset(config_keep_nan, requirements)
+ data_with_nan, _ = dataset[0]
+ assert torch.all(torch.isnan(data_with_nan["foo"][0, :, 0]))
+
+
+def test_overwrite(mock_monthly_netcdfs):
+ const = -10
+ multiple = 3.5
+
+ overwrite_config = OverwriteConfig(
+ constant={"foo": const},
+ multiply_scalar={"bar": multiple},
+ )
+
config = XarrayDataConfig(data_path=mock_monthly_netcdfs.tmpdir)
+ n_timesteps = 2
dataset = XarrayDataset(
config,
DataRequirements(
- names=mock_monthly_netcdfs.var_names.all_names, n_timesteps=10
+ names=mock_monthly_netcdfs.var_names.all_names, n_timesteps=n_timesteps
),
+ )[0][0]
+
+ config_overwrite = XarrayDataConfig(
+ data_path=mock_monthly_netcdfs.tmpdir, overwrite=overwrite_config
)
- assert dataset.all_times.equals(xr.CFTimeIndex(mock_monthly_netcdfs.obs_times))
+ n_timesteps = 2
+ dataset_overwrite = XarrayDataset(
+ config_overwrite,
+ DataRequirements(
+ names=mock_monthly_netcdfs.var_names.all_names, n_timesteps=n_timesteps
+ ),
+ )[0][0]
+
+ for v in ["foo", "bar"]:
+ assert dataset_overwrite[v].dtype == dataset[v].dtype
+ assert dataset_overwrite[v].device == dataset[v].device
+ assert torch.equal(
+ dataset_overwrite["foo"], torch.ones_like(dataset["foo"]) * const
+ )
+ assert torch.equal(dataset_overwrite["bar"], dataset["bar"] * multiple)
+
+
+def test_overwrite_raises_error_on_original_name(mock_monthly_netcdfs):
+ overwrite_config = OverwriteConfig(
+ constant={"foo": 3},
+ )
+ with pytest.raises(ValueError):
+ XarrayDataConfig(
+ data_path=mock_monthly_netcdfs.tmpdir,
+ overwrite=overwrite_config,
+ renamed_variables={"foo": "foo_new"},
+ )
diff --git a/fme/fme/core/data_loading/utils.py b/fme/fme/core/dataset/utils.py
similarity index 65%
rename from fme/fme/core/data_loading/utils.py
rename to fme/fme/core/dataset/utils.py
index 05cb8a5..5146b52 100644
--- a/fme/fme/core/data_loading/utils.py
+++ b/fme/fme/core/dataset/utils.py
@@ -1,23 +1,15 @@
-import dataclasses
import datetime
import warnings
from typing import Hashable, List, Optional, Sequence, Tuple
-import cftime
-import numpy as np
import torch
import xarray as xr
-from torch.utils.data import default_collate
-
-from fme.core.typing_ import TensorMapping
-
-from .data_typing import HorizontalCoordinates
SLICE_NONE = slice(None)
-def infer_horizontal_dimension_names(ds: xr.Dataset) -> List[Hashable]:
- hdims: List[Hashable]
+def infer_horizontal_dimension_names(ds: xr.Dataset) -> List[str]:
+ hdims: List[str]
if "grid_xt" in ds.variables:
hdims = ["grid_xt", "grid_yt"]
elif "lon" in ds.variables:
@@ -112,10 +104,10 @@ def load_series_data(
ds: xr.Dataset,
names: List[str],
time_dim: Hashable,
- spatial_dims: HorizontalCoordinates,
+ spatial_dim_names: List[str],
):
time_slice = slice(idx, idx + n_steps)
- dims = [time_dim] + spatial_dims.dims
+ dims = [time_dim] + spatial_dim_names
shape = [n_steps] + [ds.sizes[spatial_dim] for spatial_dim in dims[1:]]
loaded = _load_all_variables(ds, names, time_slice)
arrays = {}
@@ -125,70 +117,21 @@ def load_series_data(
return arrays
-def get_horizontal_dimensions(ds: xr.Dataset) -> List[np.ndarray]:
+def get_horizontal_dimensions(
+ ds: xr.Dataset, dtype: Optional[torch.dtype]
+) -> List[torch.Tensor]:
hdims = infer_horizontal_dimension_names(ds)
horizontal_values = []
for dim in hdims:
if dim in ds:
- horizontal_values.append(np.array(ds[dim].values, dtype=np.float32))
+ horizontal_values.append(torch.tensor(ds[dim].values, dtype=dtype))
else:
raise ValueError(f"Expected {dim} in dataset: {ds}.")
return horizontal_values
-def get_times(ds: xr.Dataset, start: int, n_steps: int) -> xr.DataArray:
- """
- Get the time coordinate segment from the dataset, check that it's a
- cftime.datetime object, and return it is a data array (not a coordinate),
- so that it can be concatenated with other samples' times.
- """
- time_segment = ds["time"][slice(start, start + n_steps)]
- assert isinstance(
- time_segment[0].item(), cftime.datetime
- ), "time must be cftime.datetime."
- if len(time_segment) != n_steps:
- raise ValueError(
- f"Expected {n_steps} time steps, but got {len(time_segment)} instead."
- )
- return time_segment.drop_vars(["time"])
-
-
-@dataclasses.dataclass
-class BatchData:
- """A container for the data and time coordinates of a batch.
-
- Attributes:
- data: Data for each variable in each sample, concatenated along samples
- to make a batch. To be used directly in training, validation, and
- inference.
- times: An array of times for each sample in the batch, concatenated along
- samples to make a batch. To be used in writing out inference
- predictions with time coordinates, not directly in ML.
-
- """
-
- data: TensorMapping
- times: xr.DataArray
-
- @classmethod
- def from_sample_tuples(
- cls,
- samples: Sequence[Tuple[TensorMapping, xr.DataArray]],
- sample_dim_name: str = "sample",
- ) -> "BatchData":
- """
- Collate function for use with PyTorch DataLoader. Needed since samples contain
- both tensor mapping and xarray time coordinates, the latter of which we do
- not want to convert to tensors.
- """
- sample_data, sample_times = zip(*samples)
- batch_data = default_collate(sample_data)
- batch_times = xr.concat(sample_times, dim=sample_dim_name)
- return cls(batch_data, batch_times)
-
-
def decode_timestep(microseconds: int) -> datetime.timedelta:
return datetime.timedelta(microseconds=microseconds)
diff --git a/fme/fme/core/data_loading/_xarray.py b/fme/fme/core/dataset/xarray.py
similarity index 60%
rename from fme/fme/core/data_loading/_xarray.py
rename to fme/fme/core/dataset/xarray.py
index 6d6a85e..82ef2a6 100644
--- a/fme/fme/core/data_loading/_xarray.py
+++ b/fme/fme/core/dataset/xarray.py
@@ -13,24 +13,21 @@
import torch
import xarray as xr
-import fme
-from fme.core import metrics
-from fme.core.typing_ import TensorDict
-
-from .config import Slice, TimeSlice, XarrayDataConfig
-from .data_typing import (
- Dataset,
+from fme.core.coordinates import (
HEALPixCoordinates,
HorizontalCoordinates,
+ HybridSigmaPressureCoordinate,
LatLonCoordinates,
- SigmaCoordinates,
- VariableMetadata,
)
+from fme.core.device import get_device
+from fme.core.typing_ import Slice, TensorDict
+
+from .config import RepeatedInterval, TimeSlice, XarrayDataConfig
+from .data_typing import Dataset, VariableMetadata
from .requirements import DataRequirements
from .utils import (
as_broadcasted_tensor,
get_horizontal_dimensions,
- get_times,
infer_horizontal_dimension_names,
load_series_data,
)
@@ -48,30 +45,19 @@
)
-def subset_dataset(dataset: Dataset, subset: slice) -> Dataset:
- """Returns a subset of the dataset and propagates other properties."""
- indices = range(len(dataset))[subset]
- logging.info(f"Subsetting dataset samples according to {subset}.")
- subsetted_dataset = torch.utils.data.Subset(dataset, indices)
- subsetted_dataset.metadata = dataset.metadata
- subsetted_dataset.area_weights = dataset.area_weights
- subsetted_dataset.sigma_coordinates = dataset.sigma_coordinates
- subsetted_dataset.horizontal_coordinates = dataset.horizontal_coordinates
- subsetted_dataset.timestep = dataset.timestep
- subsetted_dataset.is_remote = dataset.is_remote
- subsetted_dataset.sample_start_times = dataset.sample_start_times[subset]
- return subsetted_dataset
-
-
-def get_sigma_coordinates(ds: xr.Dataset) -> SigmaCoordinates:
+def _get_vertical_coordinates(
+ ds: xr.Dataset, dtype: Optional[torch.dtype]
+) -> HybridSigmaPressureCoordinate:
"""
- Get sigma coordinates from a dataset.
+ Get hybrid sigma-pressure coordinates from a dataset.
Assumes that the dataset contains variables named `ak_N` and `bk_N` where
`N` is the level number. The returned tensors are sorted by level number.
Args:
- ds: Dataset to get sigma coordinates from.
+ ds: Dataset to get vertical coordinates from.
+ dtype: Data type of the returned tensors. If None, the dtype is not
+ changed from the original in ds.
"""
ak_mapping = {
int(v[3:]): torch.as_tensor(ds[v].values)
@@ -88,20 +74,14 @@ def get_sigma_coordinates(ds: xr.Dataset) -> SigmaCoordinates:
if len(ak_list) == 0 or len(bk_list) == 0:
logger.warning("Dataset does not contain ak and bk coordinates.")
- return SigmaCoordinates(
- ak=torch.tensor([], device=fme.get_device()),
- bk=torch.tensor([], device=fme.get_device()),
+ return HybridSigmaPressureCoordinate(
+ ak=torch.tensor([]),
+ bk=torch.tensor([]),
)
- if len(ak_list) != len(bk_list):
- raise ValueError(
- "Expected same number of ak and bk coordinates, "
- f"got {len(ak_list)} and {len(bk_list)}."
- )
-
- return SigmaCoordinates(
- ak=torch.as_tensor(ak_list, device=fme.get_device(), dtype=torch.float),
- bk=torch.as_tensor(bk_list, device=fme.get_device(), dtype=torch.float),
+ return HybridSigmaPressureCoordinate(
+ ak=torch.as_tensor(ak_list, dtype=dtype),
+ bk=torch.as_tensor(bk_list, dtype=dtype),
)
@@ -113,26 +93,26 @@ def get_raw_times(paths: List[str], engine: str) -> List[np.ndarray]:
return times
-def repeat_and_increment_times(
+def repeat_and_increment_time(
raw_times: List[np.ndarray], n_repeats: int, timestep: datetime.timedelta
) -> List[np.ndarray]:
"""Repeats and increments a collection of arrays of evenly spaced times."""
n_timesteps = sum(len(times) for times in raw_times)
timespan = timestep * n_timesteps
- repeated_and_incremented_times = []
+ repeated_and_incremented_time = []
for repeats in range(n_repeats):
increment = repeats * timespan
- for times in raw_times:
- incremented_times = times + increment
- repeated_and_incremented_times.append(incremented_times)
- return repeated_and_incremented_times
+ for time in raw_times:
+ incremented_time = time + increment
+ repeated_and_incremented_time.append(incremented_time)
+ return repeated_and_incremented_time
-def get_cumulative_timesteps(times: List[np.ndarray]) -> np.ndarray:
- """Returns a list of cumulative timesteps for each item in times."""
+def get_cumulative_timesteps(time: List[np.ndarray]) -> np.ndarray:
+ """Returns a list of cumulative timesteps for each item in a time coordinate."""
num_timesteps_per_file = [0]
- for time_coord in times:
+ for time_coord in time:
num_timesteps_per_file.append(len(time_coord))
return np.array(num_timesteps_per_file).cumsum()
@@ -261,18 +241,65 @@ def _open_file_fh_cached(path, **kwargs):
protocol = _get_protocol(path)
if protocol:
# add an LRU cache for remote zarrs
- fn = _open_xr_dataset_lru
- else:
- # netcdf4 and h5engine have a filehandle LRU cache in xarray
- # https://github.com/pydata/xarray/blob/cd3ab8d5580eeb3639d38e1e884d2d9838ef6aa1/xarray/backends/file_manager.py#L54 # noqa: E501
- fn = _open_xr_dataset
-
- return fn(
+ return _open_xr_dataset_lru(
+ path,
+ **kwargs,
+ )
+ # netcdf4 and h5engine have a filehandle LRU cache in xarray
+ # https://github.com/pydata/xarray/blob/cd3ab8d5580eeb3639d38e1e884d2d9838ef6aa1/xarray/backends/file_manager.py#L54 # noqa: E501
+ return _open_xr_dataset(
path,
**kwargs,
)
+class DatasetProperties:
+ def __init__(
+ self,
+ variable_metadata: Mapping[str, VariableMetadata],
+ vertical_coordinate: HybridSigmaPressureCoordinate,
+ horizontal_coordinates: HorizontalCoordinates,
+ timestep: datetime.timedelta,
+ is_remote: bool,
+ ):
+ self.variable_metadata = variable_metadata
+ self.vertical_coordinate = vertical_coordinate
+ self.horizontal_coordinates = horizontal_coordinates
+ self.timestep = timestep
+ self.is_remote = is_remote
+
+ def to_device(self) -> "DatasetProperties":
+ device = get_device()
+ return DatasetProperties(
+ self.variable_metadata,
+ self.vertical_coordinate.to(device),
+ self.horizontal_coordinates.to(device),
+ self.timestep,
+ self.is_remote,
+ )
+
+ def update(self, other: "DatasetProperties"):
+ self.is_remote = self.is_remote or other.is_remote
+ if self.timestep != other.timestep:
+ raise ValueError("Inconsistent timesteps between datasets")
+ if self.variable_metadata != other.variable_metadata:
+ raise ValueError("Inconsistent metadata between datasets")
+ if self.vertical_coordinate != other.vertical_coordinate:
+ raise ValueError("Inconsistent vertical coordinates between datasets")
+ if self.horizontal_coordinates != other.horizontal_coordinates:
+ raise ValueError("Inconsistent horizontal coordinates between datasets")
+
+
+def get_xarray_dataset(
+ config: XarrayDataConfig, requirements: DataRequirements
+) -> Tuple["Dataset", DatasetProperties]:
+ dataset = XarrayDataset(config, requirements)
+ properties = dataset.properties
+ index_slice = as_index_selection(config.subset, dataset)
+ dataset = dataset.subset(index_slice)
+ return dataset, properties
+
+
class XarrayDataset(Dataset):
"""Load data from a directory of files matching a pattern using xarray. The
number of contiguous timesteps to load for each sample is specified by
@@ -280,7 +307,8 @@ class XarrayDataset(Dataset):
For example, if the file(s) have the time coordinate
(t0, t1, t2, t3, t4) and requirements.n_timesteps=3, then this dataset will
- provide three samples: (t0, t1, t2), (t1, t2, t3), and (t2, t3, t4)."""
+ provide three samples: (t0, t1, t2), (t1, t2, t3), and (t2, t3, t4).
+ """
def __init__(
self,
@@ -288,17 +316,20 @@ def __init__(
requirements: DataRequirements,
):
self._horizontal_coordinates: HorizontalCoordinates
- self.names = requirements.names
+ self.renamed_variables = config.renamed_variables or {}
+ self._names = self._get_names_to_load(requirements.names)
self.path = config.data_path
self.file_pattern = config.file_pattern
- self.engine = "netcdf4" if config.engine is None else config.engine
- self._default_file_pattern_check()
+ self.engine = config.engine
+ self.dtype = config.torch_dtype
+ self.spatial_dimensions = config.spatial_dimensions
+ self.fill_nans = config.fill_nans
fs = _get_fs(self.path)
glob_paths = sorted(fs.glob(os.path.join(self.path, config.file_pattern)))
self._raw_paths = _preserve_protocol(self.path, glob_paths)
if len(self._raw_paths) == 0:
raise ValueError(
- f"No files found matching '{self.path}/{config.file_pattern}'."
+ f"No files found matching '{self.path}/{self.file_pattern}'."
)
self.full_paths = self._raw_paths * config.n_repeats
self.n_steps = requirements.n_timesteps # one input, n_steps - 1 outputs
@@ -310,15 +341,34 @@ def __init__(
)
(
self._horizontal_coordinates,
- self._area_weights,
self._static_derived_data,
- ) = self.configure_horizontal_coordinates(config, first_dataset)
+ ) = self.configure_horizontal_coordinates(first_dataset)
(
- self.time_dependent_names,
- self.time_invariant_names,
- self.static_derived_names,
+ self._time_dependent_names,
+ self._time_invariant_names,
+ self._static_derived_names,
) = self._group_variable_names_by_time_type()
- self._sigma_coordinates = get_sigma_coordinates(first_dataset)
+ self._vertical_coordinates = _get_vertical_coordinates(
+ first_dataset, self.dtype
+ )
+ self.overwrite = config.overwrite
+
+ @property
+ def properties(self) -> DatasetProperties:
+ return DatasetProperties(
+ self._variable_metadata,
+ self._vertical_coordinates,
+ self._horizontal_coordinates,
+ self.timestep,
+ self.is_remote,
+ )
+
+ def _get_names_to_load(self, names: List[str]) -> List[str]:
+ # requirements.names from stepper config refer to the final set of
+ # variables after any renaming occurs. This returns the set of names
+ # to load from data before renaming.
+ inverted_renaming = {v: k for k, v in self.renamed_variables.items()}
+ return [inverted_renaming.get(n, n) for n in names]
@property
def horizontal_coordinates(self) -> HorizontalCoordinates:
@@ -333,20 +383,12 @@ def is_remote(self) -> bool:
@property
def all_times(self) -> xr.CFTimeIndex:
- """Time index of all available times in the data"""
+ """Time index of all available times in the data."""
return self._all_times
- def _default_file_pattern_check(self):
- if self.engine == "zarr" and self.file_pattern == "*.nc":
- raise ValueError(
- "The file pattern is set to the default NetCDF file pattern *.nc "
- "but the engine is specified as 'zarr'. Please set "
- "`XarrayDataConfig.file_pattern` to match the zarr filename."
- )
-
- def _get_metadata(self, ds):
+ def _get_variable_metadata(self, ds):
result = {}
- for name in self.names:
+ for name in self._names:
if name in StaticDerivedData.names:
result[name] = StaticDerivedData.metadata[name]
elif hasattr(ds[name], "units") and hasattr(ds[name], "long_name"):
@@ -354,7 +396,7 @@ def _get_metadata(self, ds):
units=ds[name].units,
long_name=ds[name].long_name,
)
- self._metadata = result
+ self._variable_metadata = result
def _get_files_stats(self, n_repeats: int, infer_timestep: bool):
logging.info(f"Opening data at {os.path.join(self.path, self.file_pattern)}")
@@ -363,44 +405,32 @@ def _get_files_stats(self, n_repeats: int, infer_timestep: bool):
self._timestep: Optional[datetime.timedelta]
if infer_timestep:
self._timestep = get_timestep(np.concatenate(raw_times))
- time_coords = repeat_and_increment_times(
- raw_times, n_repeats, self.timestep
- )
+ time_coord = repeat_and_increment_time(raw_times, n_repeats, self.timestep)
else:
self._timestep = None
- time_coords = raw_times
+ time_coord = raw_times
- cum_num_timesteps = get_cumulative_timesteps(time_coords)
+ cum_num_timesteps = get_cumulative_timesteps(time_coord)
self.start_indices = cum_num_timesteps[:-1]
self.total_timesteps = cum_num_timesteps[-1]
self._n_initial_conditions = self.total_timesteps - self.n_steps + 1
- self._sample_start_times = xr.CFTimeIndex(
- np.concatenate(time_coords)[: self._n_initial_conditions]
+ self._sample_start_time = xr.CFTimeIndex(
+ np.concatenate(time_coord)[: self._n_initial_conditions]
)
- self._all_times = xr.CFTimeIndex(np.concatenate(raw_times))
+ self._all_times = xr.CFTimeIndex(np.concatenate(time_coord))
- del cum_num_timesteps, time_coords
+ del cum_num_timesteps, time_coord
ds = self._open_file(0)
- self._get_metadata(ds)
+ self._get_variable_metadata(ds)
- for i in range(len(self.names)):
- if self.names[i] in ds.variables:
- img_shape = ds[self.names[i]].shape[-2:]
- break
- else:
- raise ValueError(
- f"None of the requested variables {self.names} are present "
- f"in the dataset."
- )
logging.info(f"Found {self._n_initial_conditions} samples.")
- logging.info(f"Image shape is {img_shape[0]} x {img_shape[1]}.")
- logging.info(f"Following variables are available: {list(ds.variables)}.")
def _group_variable_names_by_time_type(self) -> VariableNames:
"""Returns lists of time-dependent variable names, time-independent
variable names, and variables which are only present as an initial
- condition."""
+ condition.
+ """
(
time_dependent_names,
time_invariant_names,
@@ -410,72 +440,81 @@ def _group_variable_names_by_time_type(self) -> VariableNames:
# fields a time dimension. We assume that all fields are present in the
# netcdf file corresponding to the first chunk of time.
with _open_xr_dataset(self.full_paths[0], engine=self.engine) as ds:
- for name in self.names:
+ for name in self._names:
if name in StaticDerivedData.names:
static_derived_names.append(name)
else:
- dims = ds[name].dims
- if "time" in dims:
- time_dependent_names.append(name)
+ try:
+ da = ds[name]
+ except KeyError:
+ raise ValueError(
+ f"Required variable not found in dataset: {name}."
+ )
else:
- time_invariant_names.append(name)
+ dims = da.dims
+ if "time" in dims:
+ time_dependent_names.append(name)
+ else:
+ time_invariant_names.append(name)
+ logging.info(
+ f"Found all required variables in the dataset: {self._names}."
+ )
+
return VariableNames(
time_dependent_names,
time_invariant_names,
static_derived_names,
)
- def configure_horizontal_coordinates(self, config, first_dataset):
+ def configure_horizontal_coordinates(
+ self, first_dataset
+ ) -> Tuple[HorizontalCoordinates, StaticDerivedData]:
horizontal_coordinates: HorizontalCoordinates
- area_weights: torch.Tensor
static_derived_data: StaticDerivedData
- dims = get_horizontal_dimensions(first_dataset)
+ dims = get_horizontal_dimensions(first_dataset, self.dtype)
- if config.spatial_dimensions == "latlon":
+ if self.spatial_dimensions == "latlon":
lons = dims[0]
lats = dims[1]
names = infer_horizontal_dimension_names(first_dataset)
lon_name = names[0]
lat_name = names[1]
horizontal_coordinates = LatLonCoordinates(
- lon=torch.as_tensor(lons, device=fme.get_device()),
- lat=torch.as_tensor(lats, device=fme.get_device()),
- lat_name=lat_name,
- lon_name=lon_name,
+ lon=lons,
+ lat=lats,
+ loaded_lat_name=lat_name,
+ loaded_lon_name=lon_name,
)
- area_weights = metrics.spherical_area_weights(lats, len(lons))
static_derived_data = StaticDerivedData(horizontal_coordinates)
- elif config.spatial_dimensions == "healpix":
+ elif self.spatial_dimensions == "healpix":
face = dims[0]
height = dims[1]
width = dims[2]
horizontal_coordinates = HEALPixCoordinates(
- face=torch.as_tensor(face, device=fme.get_device()),
- height=torch.as_tensor(height, device=fme.get_device()),
- width=torch.as_tensor(width, device=fme.get_device()),
+ face=face,
+ height=height,
+ width=width,
)
- # Area weights should be all 1's of shape (face, height, width),
- # since area is uniform.
- area_weights = torch.ones((len(face), len(height), len(width)))
static_derived_data = StaticDerivedData(horizontal_coordinates)
else:
raise ValueError(
- f"unexpected config.spatial_dimensions {config.spatial_dimensions},"
+ f"unexpected config.spatial_dimensions {self.spatial_dimensions},"
" should be one of 'latlon' or 'healpix'"
)
- return horizontal_coordinates, area_weights, static_derived_data
+ coords_sizes = {
+ coord_name: len(coord)
+ for coord_name, coord in horizontal_coordinates.coords.items()
+ }
+ logging.info(f"Horizontal coordinate sizes are {coords_sizes}.")
+ return horizontal_coordinates, static_derived_data
@property
- def area_weights(self) -> torch.Tensor:
- return self._area_weights
+ def variable_metadata(self) -> Mapping[str, VariableMetadata]:
+ return self._variable_metadata
@property
- def metadata(self) -> Mapping[str, VariableMetadata]:
- return self._metadata
-
- @property
- def sigma_coordinates(self) -> SigmaCoordinates:
- return self._sigma_coordinates
+ def vertical_coordinate(self) -> HybridSigmaPressureCoordinate:
+ return self._vertical_coordinates
@property
def timestep(self) -> datetime.timedelta:
@@ -496,9 +535,9 @@ def _open_file(self, idx):
return _open_file_fh_cached(self.full_paths[idx], engine=self.engine)
@property
- def sample_start_times(self) -> xr.CFTimeIndex:
+ def sample_start_time(self) -> xr.CFTimeIndex:
"""Return cftime index corresponding to start time of each sample."""
- return self._sample_start_times
+ return self._sample_start_time
def __getitem__(self, idx: int) -> Tuple[TensorDict, xr.DataArray]:
"""Return a sample of data spanning the timesteps [idx, idx + self.n_steps).
@@ -525,26 +564,26 @@ def get_sample_by_time_slice(
# get the sequence of observations
arrays: Dict[str, List[torch.Tensor]] = {}
- times_segments: List[xr.DataArray] = []
idxs = range(input_file_idx, output_file_idx + 1)
total_steps = 0
for i, file_idx in enumerate(idxs):
ds = self._open_file(file_idx)
+ if self.fill_nans is not None:
+ ds = ds.fillna(self.fill_nans.value)
start = input_local_idx if i == 0 else 0
stop = output_local_idx if i == len(idxs) - 1 else len(ds["time"]) - 1
n_steps = stop - start + 1
total_steps += n_steps
tensor_dict = load_series_data(
- start,
- n_steps,
- ds,
- self.time_dependent_names,
- "time",
- self._horizontal_coordinates,
+ idx=start,
+ n_steps=n_steps,
+ ds=ds,
+ names=self._time_dependent_names,
+ time_dim="time",
+ spatial_dim_names=self._horizontal_coordinates.loaded_dims,
)
- for n in self.time_dependent_names:
+ for n in self._time_dependent_names:
arrays.setdefault(n, []).append(tensor_dict[n])
- times_segments.append(get_times(ds, start, n_steps))
ds.close()
del ds
@@ -552,49 +591,77 @@ def get_sample_by_time_slice(
for n, tensor_list in arrays.items():
tensors[n] = torch.cat(tensor_list)
del arrays
- times: xr.DataArray = xr.concat(times_segments, dim="time")
# load time-invariant variables from first dataset
- if len(self.time_invariant_names) > 0:
+ if len(self._time_invariant_names) > 0:
ds = self._open_file(idxs[0])
- dims = ["time"] + self._horizontal_coordinates.dims
+ dims = ["time"] + self._horizontal_coordinates.loaded_dims
shape = [total_steps] + [ds.sizes[dim] for dim in dims[1:]]
- for name in self.time_invariant_names:
+ for name in self._time_invariant_names:
variable = ds[name].variable
tensors[name] = as_broadcasted_tensor(variable, dims, shape)
ds.close()
del ds
# load static derived variables
- for name in self.static_derived_names:
+ for name in self._static_derived_names:
tensor = self._static_derived_data[name]
tensors[name] = tensor.repeat((total_steps, 1, 1))
- return tensors, times
+ # cast to desired dtype
+ tensors = {k: v.to(dtype=self.dtype) for k, v in tensors.items()}
+
+ # apply renaming
+ for original_name, new_name in self.renamed_variables.items():
+ tensors[new_name] = tensors.pop(original_name)
+ # Apply field overwrites
+ tensors = self.overwrite.apply(tensors)
-def as_index_slice(subset: Union[Slice, TimeSlice], dataset: XarrayDataset) -> slice:
+ # Create a DataArray of times corresponding to the time slice; this is
+ # valid even when n_repeats > 1.
+ time = xr.DataArray(self.all_times[time_slice].values, dims=["time"])
+
+ return tensors, time
+
+ def subset(self, subset: Union[slice, torch.Tensor]) -> Dataset:
+ """Returns a subset of the dataset and propagates other properties."""
+ indices = range(len(self))[subset]
+ logging.info(f"Subsetting dataset samples according to {subset}.")
+ subsetted_dataset = torch.utils.data.Subset(self, indices)
+ return subsetted_dataset
+
+
+def as_index_selection(
+ subset: Union[Slice, TimeSlice, RepeatedInterval], dataset: XarrayDataset
+) -> Union[slice, np.ndarray]:
"""Converts a subset defined either as a Slice or TimeSlice into an index slice
- based on time coordinate in provided dataset."""
+ based on time coordinate in provided dataset.
+ """
if isinstance(subset, Slice):
- index_slice = subset.slice
+ index_selection = subset.slice
elif isinstance(subset, TimeSlice):
- index_slice = subset.slice(dataset.sample_start_times)
+ index_selection = subset.slice(dataset.sample_start_time)
+ elif isinstance(subset, RepeatedInterval):
+ try:
+ index_selection = subset.get_boolean_mask(len(dataset), dataset.timestep)
+ except ValueError as e:
+ raise ValueError(f"Error when applying RepeatedInterval to dataset: {e}")
else:
- raise TypeError(f"subset must be Slice or TimeSlice, got {type(subset)}")
+ raise TypeError(
+ f"subset must be Slice, TimeSlice, or RepeatedInterval, got {type(subset)}"
+ )
- return index_slice
+ return index_selection
-def get_timestep(times: np.ndarray) -> datetime.timedelta:
- """Computes the timestep of an array of times.
+def get_timestep(time: np.ndarray) -> datetime.timedelta:
+ """Computes the timestep of a time coordinate array.
Raises an error if the times are not separated by a positive constant
interval, or if the array has one or fewer times.
"""
- assert len(times.shape) == 1, "times must be a 1D array"
+ assert len(time.shape) == 1, "time must be a 1D array"
- if len(times) > 1:
- timesteps = np.diff(times)
+ if len(time) > 1:
+ timesteps = np.diff(time)
timestep = timesteps[0]
if not (timestep > datetime.timedelta(days=0)):
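# A minimal sketch (not part of the diff) of what get_timestep infers from a 1D
# time coordinate. Plain datetimes stand in for the dataset's time values, and the
# import of get_timestep is omitted since its module path is not visible in this hunk.
import datetime

import numpy as np

time = np.array(
    [
        datetime.datetime(2020, 1, 1) + datetime.timedelta(hours=6 * i)
        for i in range(4)
    ]
)
# get_timestep requires these differences to be a positive, constant interval
# and returns that interval (here, six hours).
timesteps = np.diff(time)
assert all(step == datetime.timedelta(hours=6) for step in timesteps)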
diff --git a/fme/fme/core/device.py b/fme/fme/core/device.py
index 0c24044..f0dfab3 100644
--- a/fme/fme/core/device.py
+++ b/fme/fme/core/device.py
@@ -2,6 +2,8 @@
import torch
+from .typing_ import TensorDict, TensorMapping
+
def using_gpu() -> bool:
return get_device().type == "cuda"
@@ -9,12 +11,18 @@ def using_gpu() -> bool:
def get_device() -> torch.device:
"""If CUDA is available, return a CUDA device. Otherwise, return a CPU device
- unless FME_USE_MPS is set, in which case return an MPS device if available."""
+ unless FME_USE_MPS is set, in which case return an MPS device if available.
+ """
if torch.cuda.is_available():
return torch.device("cuda", torch.cuda.current_device())
else:
mps_available = torch.backends.mps.is_available()
if mps_available and os.environ.get("FME_USE_MPS", "0") == "1":
- return torch.device("mps")
+ return torch.device("mps", 0)
else:
return torch.device("cpu")
+
+
+def move_tensordict_to_device(data: TensorMapping) -> TensorDict:
+ device = get_device()
+ return {name: value.to(device) for name, value in data.items()}
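# A minimal usage sketch (not part of the diff) of the new move_tensordict_to_device
# helper added above: it maps every tensor in a dict onto the device chosen by
# get_device(). The tensor names and shapes below are placeholders.
import torch

from fme.core.device import get_device, move_tensordict_to_device

batch = {"temperature": torch.zeros(2, 8, 16), "pressure": torch.ones(2, 8, 16)}
batch_on_device = move_tensordict_to_device(batch)
assert all(v.device == get_device() for v in batch_on_device.values())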
diff --git a/fme/fme/core/dicts.py b/fme/fme/core/dicts.py
index 1c2ffda..1b07b78 100644
--- a/fme/fme/core/dicts.py
+++ b/fme/fme/core/dicts.py
@@ -5,9 +5,8 @@ def to_flat_dict(d: Mapping[str, Any]) -> Dict[str, Any]:
"""
Converts any nested dictionaries to a flat version with
the nested keys joined with a '.', e.g., {a: {b: 1}} ->
- {a.b: 1}
+ {a.b: 1}.
"""
-
new_flat = {}
for k, v in d.items():
if isinstance(v, dict):
@@ -23,9 +22,8 @@ def to_flat_dict(d: Mapping[str, Any]) -> Dict[str, Any]:
def to_nested_dict(d: Mapping[str, Any]) -> Dict[str, Any]:
"""
Converts a flat dictionary with '.' joined keys back into
- a nested dictionary, e.g., {a.b: 1} -> {a: {b: 1}}
+ a nested dictionary, e.g., {a.b: 1} -> {a: {b: 1}}.
"""
-
new_config: Dict[str, Any] = {}
for k, v in d.items():
diff --git a/fme/fme/core/distributed.py b/fme/fme/core/distributed.py
index eb12cb2..5df8151 100644
--- a/fme/fme/core/distributed.py
+++ b/fme/fme/core/distributed.py
@@ -2,6 +2,7 @@
from typing import Callable, List, Optional, Union
import torch.distributed
+from torch.nn import SyncBatchNorm
from torch.nn.functional import pad
from torch.nn.parallel import DistributedDataParallel
@@ -40,7 +41,7 @@ class Distributed:
variables without having to pass them around, and lets us put the initialization
for this global state in the same place as the routines that use it.
- Attributes:
+ Parameters:
world_size: The number of processes in the distributed training job.
rank: The global rank of the current process.
local_rank: The node-local rank of the current process.
@@ -179,7 +180,6 @@ def gather_irregular(
A list of tensors of consistent shape, where the i-th element is the tensor
from the i-th process.
"""
-
return gather_irregular(
tensor,
self.reduce_max,
@@ -212,7 +212,7 @@ def wrap_module(self, module: torch.nn.Module) -> torch.nn.Module:
device_ids = None
output_device = None
return DistributedDataParallel(
- module,
+ SyncBatchNorm.convert_sync_batchnorm(module),
device_ids=device_ids,
output_device=output_device,
)
@@ -244,7 +244,6 @@ def gather_irregular(
Returns:
A list of tensors, where the i-th element is the tensor from the i-th process.
"""
-
output_tensor_size = []
tensor_size = list(tensor.size())
for dim_len in tensor_size:
@@ -281,7 +280,7 @@ def pad_tensor_at_end(
fill_value: Union[float, int] = 0.0,
):
"""Pad tensor by specified amount at end of each dimension.
- Note that `pad` format is in reverse dimension order
+ Note that `pad` format is in reverse dimension order.
Args:
tensor: The tensor to pad
@@ -311,7 +310,7 @@ def pad_tensor_at_end(
def unpad_tensor_at_end(
tensor: torch.Tensor, dimension_difference: torch.Tensor
) -> torch.Tensor:
- """Remove padding from tensor
+ """Remove padding from tensor.
Args:
tensor: The tensor to remove padding from
diff --git a/fme/fme/core/ema.py b/fme/fme/core/ema.py
index dd1e336..8966b72 100644
--- a/fme/fme/core/ema.py
+++ b/fme/fme/core/ema.py
@@ -1,5 +1,5 @@
"""
-Exponential Moving Average (EMA) module
+Exponential Moving Average (EMA) module.
Copied from https://github.com/CompVis/latent-diffusion/blob/main/ldm/modules/ema.py
and modified.
@@ -37,8 +37,7 @@
class HasNamedParameters(Protocol):
def named_parameters(
self, recurse: bool = True
- ) -> Iterator[Tuple[str, nn.Parameter]]:
- ...
+ ) -> Iterator[Tuple[str, nn.Parameter]]: ...
@dataclasses.dataclass
@@ -46,7 +45,7 @@ class EMAConfig:
"""
Configuration for exponential moving average of model weights.
- Attributes:
+ Parameters:
decay: decay rate for the moving average
"""
@@ -198,7 +197,7 @@ def from_state(cls, state, model) -> "EMATracker":
Returns:
The EMA tracker.
"""
- ema = cls(model, state["decay"], state["faster_decay_at_start"])
+ ema = cls(model, float(state["decay"]), state["faster_decay_at_start"])
ema.num_updates = state["num_updates"]
ema._module_name_to_ema_name = state["module_name_to_ema_name"]
return ema
diff --git a/fme/fme/core/generics/aggregator.py b/fme/fme/core/generics/aggregator.py
new file mode 100644
index 0000000..146eefd
--- /dev/null
+++ b/fme/fme/core/generics/aggregator.py
@@ -0,0 +1,53 @@
+import abc
+from typing import Any, Dict, Generic, List, TypeVar
+
+PS = TypeVar("PS", contravariant=True) # prognostic state
+T = TypeVar("T", contravariant=True)
+
+
+class AggregatorABC(abc.ABC, Generic[T]):
+ @abc.abstractmethod
+ def record_batch(self, batch: T) -> None:
+ pass
+
+ @abc.abstractmethod
+ def get_logs(self, label: str) -> Dict[str, float]:
+ pass
+
+
+InferenceLog = Dict[str, Any]
+InferenceLogs = List[InferenceLog]
+
+
+class InferenceAggregatorABC(abc.ABC, Generic[PS, T]):
+ @abc.abstractmethod
+ def record_batch(
+ self,
+ data: T,
+ ) -> InferenceLogs:
+ """
+ Record a batch of data.
+
+ Args:
+ data: Batch of data.
+
+ Returns:
+ Logs for the batch.
+ """
+ pass
+
+ @abc.abstractmethod
+ def record_initial_condition(
+ self,
+ initial_condition: PS,
+ ) -> InferenceLogs:
+ """
+ Record the initial condition.
+
+ May only be recorded once, before any calls to record_batch.
+ """
+ pass
+
+ @abc.abstractmethod
+ def get_summary_logs(self) -> InferenceLog:
+ pass
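# A minimal sketch (not part of the diff) of an AggregatorABC implementation. The
# scalar-loss "batch" type and the log key are illustrative choices only; the real
# aggregators operate on full batch-data objects.
from typing import Dict

from fme.core.generics.aggregator import AggregatorABC


class MeanLossAggregator(AggregatorABC[float]):
    """Accumulates a running mean of per-batch loss values."""

    def __init__(self) -> None:
        self._total = 0.0
        self._count = 0

    def record_batch(self, batch: float) -> None:
        self._total += batch
        self._count += 1

    def get_logs(self, label: str) -> Dict[str, float]:
        mean = self._total / max(self._count, 1)
        return {f"{label}/mean/loss": mean}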
diff --git a/fme/fme/core/generics/data.py b/fme/fme/core/generics/data.py
new file mode 100644
index 0000000..da62d33
--- /dev/null
+++ b/fme/fme/core/generics/data.py
@@ -0,0 +1,68 @@
+import abc
+from typing import Generic, Iterable, Protocol, Sized, TypeVar
+
+T = TypeVar("T", covariant=True)
+
+
+class DataLoader(Protocol, Generic[T], Sized, Iterable[T]):
+ pass
+
+
+PS = TypeVar("PS") # prognostic state
+FD = TypeVar("FD", covariant=True) # forcing data
+
+
+class InferenceDataABC(abc.ABC, Generic[PS, FD]):
+ @property
+ @abc.abstractmethod
+ def initial_condition(self) -> PS: ...
+
+ @property
+ @abc.abstractmethod
+ def loader(self) -> DataLoader[FD]: ...
+
+
+class SimpleInferenceData(InferenceDataABC[PS, FD]):
+ def __init__(
+ self,
+ initial_condition: PS,
+ loader: DataLoader[FD],
+ ):
+ self._initial_condition = initial_condition
+ self._loader = loader
+
+ @property
+ def initial_condition(self) -> PS:
+ return self._initial_condition
+
+ @property
+ def loader(self) -> DataLoader[FD]:
+ return self._loader
+
+
+class GriddedDataABC(abc.ABC, Generic[T]):
+ @property
+ @abc.abstractmethod
+ def loader(self) -> DataLoader[T]: ...
+
+ @property
+ @abc.abstractmethod
+ def n_samples(self) -> int: ...
+
+ @property
+ @abc.abstractmethod
+ def n_batches(self) -> int: ...
+
+ @property
+ @abc.abstractmethod
+ def batch_size(self) -> int: ...
+
+ @abc.abstractmethod
+ def set_epoch(self, epoch: int): ...
+
+ @abc.abstractmethod
+ def log_info(self, name: str):
+ """
+ Report information about the data using logging.info.
+ """
+ ...
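# A minimal sketch (not part of the diff) of SimpleInferenceData: it simply pairs an
# initial condition with a sized, iterable loader of forcing windows. The placeholder
# dict and list below stand in for the real prognostic-state and forcing-data types.
from fme.core.generics.data import SimpleInferenceData

initial_condition = {"var": 0.0}  # placeholder prognostic state
forcing_windows = [{"window": i} for i in range(3)]  # placeholder forcing data

data = SimpleInferenceData(initial_condition, forcing_windows)
assert data.initial_condition is initial_condition
assert len(data.loader) == 3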
diff --git a/fme/fme/core/generics/inference.py b/fme/fme/core/generics/inference.py
new file mode 100644
index 0000000..a8382d3
--- /dev/null
+++ b/fme/fme/core/generics/inference.py
@@ -0,0 +1,139 @@
+import logging
+from typing import Callable, Generic, Iterator, Optional, Protocol, Tuple, TypeVar
+
+import torch
+
+from fme.core.generics.aggregator import InferenceAggregatorABC, InferenceLogs
+from fme.core.generics.data import InferenceDataABC
+from fme.core.generics.writer import NullDataWriter, WriterABC
+from fme.core.timing import GlobalTimer
+from fme.core.wandb import WandB
+
+PS = TypeVar("PS") # prognostic state
+FD = TypeVar("FD", contravariant=True) # forcing data
+SD = TypeVar("SD", covariant=True) # stepped data
+
+
+class PredictFunction(Protocol, Generic[PS, FD, SD]):
+ def __call__(
+ self,
+ initial_condition: PS,
+ forcing: FD,
+ compute_derived_variables: bool = False,
+ ) -> Tuple[SD, PS]: ...
+
+
+class Looper(Generic[PS, FD, SD]):
+ """
+ Class for stepping a model forward arbitrarily many times.
+ """
+
+ def __init__(
+ self,
+ predict: PredictFunction[PS, FD, SD],
+ data: InferenceDataABC[PS, FD],
+ ):
+ """
+ Args:
+ predict: The prediction function to use.
+ data: The data to use.
+ """
+ self._predict = predict
+ self._prognostic_state = data.initial_condition
+ self._len = len(data.loader)
+ self._loader = iter(data.loader)
+
+ def __iter__(self) -> Iterator[SD]:
+ return self
+
+ def __len__(self) -> int:
+ return self._len
+
+ def __next__(self) -> SD:
+ """Return predictions for the time period corresponding to the next batch
+ of forcing data.
+ """
+ timer = GlobalTimer.get_instance()
+ with timer.context("data_loading"):
+ try:
+ forcing_data = next(self._loader)
+ except StopIteration:
+ raise StopIteration
+ output_data, self._prognostic_state = self._predict(
+ self._prognostic_state,
+ forcing=forcing_data,
+ compute_derived_variables=True,
+ )
+ return output_data
+
+ def get_prognostic_state(self) -> PS:
+ return self._prognostic_state
+
+
+def get_record_to_wandb(label: str = "") -> Callable[[InferenceLogs], None]:
+ wandb = WandB.get_instance()
+ step = 0
+
+ def record_logs(logs: InferenceLogs):
+ nonlocal step
+ for j, log in enumerate(logs):
+ if len(log) > 0:
+ if label != "":
+ log = {f"{label}/{k}": v for k, v in log.items()}
+ wandb.log(log, step=step + j)
+ step += len(logs)
+
+ return record_logs
+
+
+def run_inference(
+ predict: PredictFunction[PS, FD, SD],
+ data: InferenceDataABC[PS, FD],
+ aggregator: InferenceAggregatorABC[PS, SD],
+ writer: Optional[WriterABC[PS, SD]] = None,
+ record_logs: Optional[Callable[[InferenceLogs], None]] = None,
+):
+ """Run extended inference loop given initial condition and forcing data.
+
+ Args:
+ predict: The prediction function to use.
+ data: Provides an initial condition and appropriately aligned windows of
+ forcing data.
+ aggregator: Aggregator for collecting and reducing metrics.
+ writer: Data writer for saving the inference results to disk.
+ record_logs: Function for recording logs. By default, logs are recorded to
+ wandb.
+ """
+ if record_logs is None:
+ record_logs = get_record_to_wandb(label="inference")
+ if writer is None:
+ writer = NullDataWriter()
+ timer = GlobalTimer.get_instance()
+ with torch.no_grad():
+ looper = Looper(predict=predict, data=data)
+ with timer.context("aggregator"):
+ logs = aggregator.record_initial_condition(
+ initial_condition=data.initial_condition,
+ )
+ with timer.context("wandb_logging"):
+ record_logs(logs)
+ with timer.context("data_writer"):
+ writer.write(data.initial_condition, "initial_condition.nc")
+ n_windows = len(looper)
+ for i, batch in enumerate(looper):
+ logging.info(
+ f"Inference: processing output from window {i + 1} of {n_windows}."
+ )
+ with timer.context("data_writer"):
+ writer.append_batch(
+ batch=batch,
+ )
+ with timer.context("aggregator"):
+ logs = aggregator.record_batch(
+ data=batch,
+ )
+ with timer.context("wandb_logging"):
+ record_logs(logs)
+ with timer.context("data_writer"):
+ prognostic_state = looper.get_prognostic_state()
+ writer.write(prognostic_state, "restart.nc")
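# A self-contained toy run (not part of the diff) of the new run_inference loop,
# sketched under assumptions: the integer "stepper" and no-op aggregator below are
# stand-ins, the default NullDataWriter is assumed to be a no-op, and wandb is
# bypassed by passing record_logs explicitly.
from typing import List, Tuple

from fme.core.generics.aggregator import (
    InferenceAggregatorABC,
    InferenceLog,
    InferenceLogs,
)
from fme.core.generics.data import SimpleInferenceData
from fme.core.generics.inference import run_inference
from fme.core.timing import GlobalTimer


def predict(
    initial_condition: int, forcing: int, compute_derived_variables: bool = False
) -> Tuple[List[int], int]:
    # advance the toy prognostic state by the forcing value
    new_state = initial_condition + forcing
    return [new_state], new_state


class NoOpAggregator(InferenceAggregatorABC[int, List[int]]):
    def record_batch(self, data: List[int]) -> InferenceLogs:
        return [{}]

    def record_initial_condition(self, initial_condition: int) -> InferenceLogs:
        return [{}]

    def get_summary_logs(self) -> InferenceLog:
        return {}


with GlobalTimer():
    run_inference(
        predict=predict,
        data=SimpleInferenceData(0, [1, 2, 3]),
        aggregator=NoOpAggregator(),
        record_logs=lambda logs: None,  # skip wandb in this sketch
    )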
diff --git a/fme/fme/core/generics/optimization.py b/fme/fme/core/generics/optimization.py
new file mode 100644
index 0000000..4715c00
--- /dev/null
+++ b/fme/fme/core/generics/optimization.py
@@ -0,0 +1,50 @@
+import abc
+import contextlib
+
+import torch
+from torch import nn
+
+
+class OptimizationABC(abc.ABC):
+ @contextlib.contextmanager
+ @abc.abstractmethod
+ def autocast(self): ...
+
+ @property
+ @abc.abstractmethod
+ def learning_rate(self) -> float: ...
+
+ @abc.abstractmethod
+ def set_mode(self, module: nn.Module):
+ """
+ Sets the mode of the module to train.
+ """
+ ...
+
+ @abc.abstractmethod
+ def step_scheduler(self, valid_loss: float):
+ """
+ Step the scheduler.
+
+ Args:
+ valid_loss: The validation loss. Used in schedulers which change the
+ learning rate based on whether the validation loss is decreasing.
+ """
+ ...
+
+ @abc.abstractmethod
+ def step_weights(self, loss: torch.Tensor): ...
+
+ @abc.abstractmethod
+ def get_state(self):
+ """
+ Returns state as a serializable data structure.
+ """
+ ...
+
+ @abc.abstractmethod
+ def load_state(self, state):
+ """
+ Loads state from a serializable data structure.
+ """
+ ...
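# A minimal sketch (not part of the diff) of an OptimizationABC implementation that
# wraps a plain SGD optimizer, to illustrate the contract: no mixed precision and no
# learning-rate scheduling in this toy version.
import contextlib
from typing import Any, Dict

import torch
from torch import nn

from fme.core.generics.optimization import OptimizationABC


class SGDOptimization(OptimizationABC):
    def __init__(self, module: nn.Module, lr: float = 1e-3):
        self._lr = lr
        self._optimizer = torch.optim.SGD(module.parameters(), lr=lr)

    @contextlib.contextmanager
    def autocast(self):
        yield  # no automatic mixed precision in this sketch

    @property
    def learning_rate(self) -> float:
        return self._lr

    def set_mode(self, module: nn.Module):
        module.train()

    def step_scheduler(self, valid_loss: float):
        pass  # constant learning rate, so nothing to schedule

    def step_weights(self, loss: torch.Tensor):
        self._optimizer.zero_grad()
        loss.backward()
        self._optimizer.step()

    def get_state(self) -> Dict[str, Any]:
        return {"optimizer": self._optimizer.state_dict()}

    def load_state(self, state: Dict[str, Any]):
        self._optimizer.load_state_dict(state["optimizer"])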
diff --git a/fme/fme/core/generics/test_looper.py b/fme/fme/core/generics/test_looper.py
new file mode 100644
index 0000000..04e671f
--- /dev/null
+++ b/fme/fme/core/generics/test_looper.py
@@ -0,0 +1,503 @@
+import datetime
+import unittest.mock
+from collections import namedtuple
+from typing import Callable, Iterable, Optional, Tuple
+
+import numpy as np
+import pytest
+import torch
+import xarray as xr
+
+import fme
+from fme.ace.data_loading.batch_data import BatchData, PrognosticState
+from fme.ace.stepper import SingleModuleStepperConfig
+from fme.core.coordinates import HybridSigmaPressureCoordinate
+from fme.core.device import get_device
+from fme.core.generics.data import SimpleInferenceData
+from fme.core.generics.inference import (
+ Looper,
+ PredictFunction,
+ get_record_to_wandb,
+ run_inference,
+)
+from fme.core.gridded_ops import LatLonOperations
+from fme.core.loss import WeightedMappingLossConfig
+from fme.core.normalizer import NormalizationConfig
+from fme.core.registry.module import ModuleSelector
+from fme.core.testing.wandb import mock_wandb
+from fme.core.timing import GlobalTimer
+from fme.core.typing_ import TensorDict, TensorMapping
+
+SphericalData = namedtuple("SphericalData", ["data", "area_weights", "vertical_coord"])
+
+
+def get_data(
+ names: Iterable[str], shape: Tuple[int, int, int, int, int]
+) -> SphericalData:
+ data = {}
+ n_lat = shape[2]
+ n_lon = shape[3]
+ shape_without_z = shape[:-1]
+ nz = shape[-1]
+ lats = torch.linspace(-89.5, 89.5, n_lat)
+ for name in names:
+ data[name] = torch.rand(*shape_without_z, device=fme.get_device())
+ area_weights = fme.spherical_area_weights(lats, n_lon).to(fme.get_device())
+ ak, bk = torch.arange(nz), torch.arange(nz)
+ vertical_coord = HybridSigmaPressureCoordinate(ak, bk)
+ return SphericalData(data, area_weights, vertical_coord)
+
+
+def get_scalar_data(names, value):
+ return {n: np.array([value], dtype=np.float32) for n in names}
+
+
+class MockLoader(torch.utils.data.DataLoader):
+ def __init__(
+ self,
+ shape: tuple,
+ names: Iterable[str],
+ n_windows: int,
+ time: Optional[xr.DataArray] = None,
+ ):
+ device = fme.get_device()
+ self._data = {n: torch.rand(*shape, device=device) for n in names}
+ if time is None:
+ self._time = xr.DataArray(np.zeros(shape[:2]), dims=["sample", "time"])
+ elif time.shape != shape[:2]:
+ raise ValueError(
+ "Time shape must match the first two dimensions of the data."
+ )
+ else:
+ self._time = time
+ self._n_windows = n_windows
+ self._current_window = 0
+
+ def __iter__(self):
+ return self
+
+ def __len__(self) -> int:
+ return self._n_windows
+
+ def __next__(self) -> BatchData:
+ if self._current_window < self._n_windows:
+ self._current_window += 1
+ return BatchData.new_on_device(
+ data=self._data,
+ time=self._time
+ + (self._current_window - 1) * (self._time.shape[1] - 1),
+ )
+ else:
+ raise StopIteration
+
+
+def _get_stepper():
+ class ChannelSum(torch.nn.Module):
+ def forward(self, x):
+ summed = torch.sum(x, dim=-3, keepdim=True)
+ diagnostic = torch.rand_like(summed)
+ return torch.concat([x, diagnostic], dim=-3)
+
+ n_samples = 2
+ n_time = 5
+ n_lat = 2
+ n_lon = 4
+ nz = 3
+ shape = (n_samples, n_time, n_lat, n_lon, nz)
+ in_names = ["forcing", "prognostic"]
+ out_names = ["prognostic", "diagnostic"]
+ all_names = list(set(in_names + out_names))
+
+ spherical_data = get_data(all_names, shape)
+ time = xr.DataArray(
+ np.repeat(np.expand_dims(np.arange(n_time), axis=0), n_samples, axis=0),
+ dims=["sample", "time"],
+ )
+
+ img_shape = spherical_data.data[in_names[0]].shape[2:]
+ gridded_operations = LatLonOperations(spherical_data.area_weights)
+ vertical_coordinate = spherical_data.vertical_coord
+ config = SingleModuleStepperConfig(
+ builder=ModuleSelector(type="prebuilt", config={"module": ChannelSum()}),
+ in_names=in_names,
+ out_names=out_names,
+ normalization=NormalizationConfig(
+ means=get_scalar_data(all_names, 0.0),
+ stds=get_scalar_data(all_names, 1.0),
+ ),
+ loss=WeightedMappingLossConfig(),
+ )
+ stepper = config.get_stepper(
+ img_shape,
+ gridded_operations,
+ vertical_coordinate,
+ datetime.timedelta(seconds=1),
+ )
+ return stepper, spherical_data, time, in_names, out_names
+
+
+def test_looper():
+ stepper, spherical_data, time, in_names, out_names = _get_stepper()
+ forcing_names = set(in_names) - set(out_names)
+ shape = spherical_data.data[in_names[0]].shape
+ initial_condition = BatchData.new_on_device(
+ data={n: spherical_data.data[n][:, :1] for n in spherical_data.data},
+ time=time[:, 0:1],
+ ).get_start(
+ prognostic_names=stepper.prognostic_names,
+ n_ic_timesteps=1,
+ )
+ loader = MockLoader(shape, forcing_names, 3, time=time)
+ looper = Looper(
+ predict=stepper.predict,
+ data=SimpleInferenceData(initial_condition, loader),
+ )
+
+ expected_output_shape = (shape[0], shape[1] - 1, shape[2], shape[3])
+ for batch in looper:
+ assert set(out_names) == set(batch.data)
+ for name in out_names:
+ assert batch.data[name].shape == expected_output_shape
+
+
+def test_looper_paired():
+ stepper, spherical_data, time, in_names, out_names = _get_stepper()
+ forcing_names = set(in_names) - set(out_names)
+ shape = spherical_data.data[in_names[0]].shape
+ initial_condition = BatchData.new_on_device(
+ data={n: spherical_data.data[n][:, :1] for n in spherical_data.data},
+ time=time[:, 0:1],
+ ).get_start(
+ prognostic_names=stepper.prognostic_names,
+ n_ic_timesteps=1,
+ )
+ loader = MockLoader(shape, forcing_names, 3, time=time)
+ looper = Looper(
+ predict=stepper.predict_paired,
+ data=SimpleInferenceData(initial_condition, loader),
+ )
+
+ expected_output_shape = (shape[0], shape[1] - 1, shape[2], shape[3])
+ for batch in looper:
+ assert set(out_names) == set(batch.prediction)
+ assert set(forcing_names) == set(batch.target)
+ for name in out_names:
+ assert batch.prediction[name].shape == expected_output_shape
+ for name in forcing_names:
+ assert batch.target[name].shape == expected_output_shape
+
+
+def _mock_compute_derived_quantities(data, forcing_data):
+ data_name = list(data)[0]
+ forcing_name = list(forcing_data)[0]
+ derived = {"derived": data[data_name] + forcing_data[forcing_name]}
+ return {**data, **derived}
+
+
+def test_looper_paired_with_derived_variables():
+ mock_derive_func = unittest.mock.MagicMock(
+ side_effect=_mock_compute_derived_quantities
+ )
+ stepper, spherical_data, time, in_names, out_names = _get_stepper()
+ stepper.derive_func = mock_derive_func
+ forcing_names = set(in_names) - set(out_names)
+ shape = spherical_data.data[in_names[0]].shape
+ initial_condition = BatchData.new_on_device(
+ {n: spherical_data.data[n][:, :1] for n in spherical_data.data},
+ time=time[:, 0:1],
+ ).get_start(
+ prognostic_names=stepper.prognostic_names,
+ n_ic_timesteps=1,
+ )
+ loader = MockLoader(shape, forcing_names, 2, time=time)
+ looper = Looper(
+ predict=stepper.predict_paired,
+ data=SimpleInferenceData(initial_condition, loader),
+ )
+
+ for batch in looper:
+ assert "derived" in batch.prediction
+ assert "derived" in batch.target
+ mock_derive_func.assert_called()
+
+
+def test_looper_paired_with_target_data():
+ stepper, spherical_data, time, in_names, out_names = _get_stepper()
+ all_names = list(set(in_names + out_names))
+ shape = spherical_data.data[in_names[0]].shape
+ initial_condition = BatchData.new_on_device(
+ data={n: spherical_data.data[n][:, :1] for n in spherical_data.data},
+ time=time[:, 0:1],
+ ).get_start(
+ prognostic_names=stepper.prognostic_names,
+ n_ic_timesteps=1,
+ )
+ loader = MockLoader(shape, all_names, 2, time=time)
+ looper = Looper(
+ predict=stepper.predict_paired,
+ data=SimpleInferenceData(initial_condition, loader),
+ )
+
+ for batch in looper:
+ assert set(out_names) == set(batch.prediction)
+ assert set(all_names) == set(batch.target)
+
+
+def test_looper_paired_with_target_data_and_derived_variables():
+ mock_derive_func = unittest.mock.MagicMock(
+ side_effect=_mock_compute_derived_quantities
+ )
+ stepper, spherical_data, time, in_names, out_names = _get_stepper()
+ stepper.derive_func = mock_derive_func
+ all_names = list(set(in_names + out_names))
+ shape = spherical_data.data[in_names[0]].shape
+ initial_condition = BatchData.new_on_device(
+ data={n: spherical_data.data[n][:, :1] for n in spherical_data.data},
+ time=time[:, 0:1],
+ ).get_start(
+ prognostic_names=stepper.prognostic_names,
+ n_ic_timesteps=1,
+ )
+ loader = MockLoader(shape, all_names, 2, time=time)
+ looper = Looper(
+ predict=stepper.predict_paired,
+ data=SimpleInferenceData(initial_condition, loader),
+ )
+
+ for batch in looper:
+ assert set(out_names + ["derived"]) == set(batch.prediction)
+ assert set(all_names + ["derived"]) == set(batch.target)
+ mock_derive_func.assert_called()
+
+
+def get_batch_data(
+ start_time,
+ n_timesteps,
+):
+ n_samples = 1
+ n_lat = 3
+ n_lon = 4
+ time_values = torch.arange(
+ start_time, start_time + n_timesteps, device=get_device()
+ )[None, :, None, None]
+ time_axis = torch.broadcast_to(
+ start_time + torch.arange(n_timesteps)[None, :], (n_samples, n_timesteps)
+ )
+ time = xr.DataArray(time_axis, dims=["sample", "time"])
+ return BatchData.new_on_device(
+ data={
+ "var": torch.broadcast_to(
+ time_values, (n_samples, n_timesteps, n_lat, n_lon)
+ )
+ },
+ time=time,
+ )
+
+
+class PlusOneStepper:
+ def __init__(
+ self,
+ n_ic_timesteps: int,
+ derive_func: Optional[
+ Callable[[TensorMapping, TensorMapping], TensorDict]
+ ] = None,
+ ):
+ self.n_ic_timesteps = n_ic_timesteps
+ if derive_func is None:
+ self.derive_func: Callable[[TensorMapping, TensorMapping], TensorDict] = (
+ unittest.mock.MagicMock(side_effect=lambda x, y=None: dict(x))
+ )
+ else:
+ self.derive_func = derive_func
+ _: PredictFunction[ # for type checking
+ PrognosticState,
+ BatchData,
+ BatchData,
+ ] = self.predict
+
+ def predict(
+ self,
+ initial_condition: PrognosticState,
+ forcing: BatchData,
+ compute_derived_variables: bool = False,
+ ) -> Tuple[BatchData, PrognosticState]:
+ ic_state = initial_condition.as_batch_data()
+ n_forward_steps = forcing.time.shape[1] - self.n_ic_timesteps
+ out_tensor = torch.zeros(
+ ic_state.data["var"].shape[0],
+ n_forward_steps,
+ *ic_state.data["var"].shape[2:],
+ device=ic_state.data["var"].device,
+ dtype=ic_state.data["var"].dtype,
+ )
+ out_tensor[:, 0, ...] = ic_state.data["var"][:, -1, ...] + 1
+ for i in range(1, n_forward_steps):
+ out_tensor[:, i, ...] = out_tensor[:, i - 1, ...] + 1
+ data = BatchData.new_on_device(
+ data={"var": out_tensor},
+ time=forcing.time[:, self.n_ic_timesteps :],
+ )
+ if compute_derived_variables:
+ data = data.compute_derived_variables(
+ derive_func=self.derive_func, forcing_data=data
+ )
+ return data, data.get_end(["var"], self.n_ic_timesteps)
+
+ def get_forward_data(
+ self,
+ forcing: BatchData,
+ compute_derived_variables: bool = False,
+ ) -> BatchData:
+ if compute_derived_variables:
+ forcing = forcing.compute_derived_variables(
+ derive_func=self.derive_func, forcing_data=forcing
+ )
+ return forcing.remove_initial_condition(self.n_ic_timesteps)
+
+
+def test_looper_simple_batch_data():
+ n_ic_timesteps = 1
+ n_forward_steps = 2
+ n_iterations = 10
+ mock_derive_func = unittest.mock.MagicMock(
+ side_effect=lambda batch_data, forcing_data: batch_data
+ )
+ stepper = PlusOneStepper(
+ n_ic_timesteps=n_ic_timesteps, derive_func=mock_derive_func
+ )
+ initial_condition = get_batch_data(
+ 0,
+ n_timesteps=n_ic_timesteps,
+ ).get_start(prognostic_names=["var"], n_ic_timesteps=n_ic_timesteps)
+ loader = [
+ get_batch_data(
+ i,
+ n_ic_timesteps + n_forward_steps,
+ )
+ for i in range(0, n_iterations * n_forward_steps, n_forward_steps)
+ ]
+
+ mock_predict = unittest.mock.MagicMock(side_effect=stepper.predict)
+ with unittest.mock.patch.object(stepper, "predict", mock_predict):
+ with GlobalTimer():
+ timer = GlobalTimer.get_instance()
+ looper = Looper(
+ predict=stepper.predict,
+ data=SimpleInferenceData(initial_condition, loader),
+ )
+ for i, batch in enumerate(looper):
+ for j in range(batch.time.shape[1]):
+ assert torch.allclose(
+ batch.data["var"][:, j, ...],
+ torch.as_tensor(n_ic_timesteps + i * n_forward_steps + j),
+ )
+ times = timer.get_durations()
+ assert times["data_loading"] > 0
+ assert mock_derive_func.call_count == n_iterations
+ assert mock_predict.call_count == n_iterations
+ # we mocked out the implicit calls that happen in .predict
+ # if this changes, update the test
+ assert "forward_prediction" not in times
+ assert "compute_derived_variables" not in times
+
+
+def get_mock_aggregator(
+ n_ic_timesteps: int,
+) -> unittest.mock.MagicMock:
+ mock_aggregator = unittest.mock.MagicMock()
+
+ # record_batch will start at step n_ic_timesteps
+ i = n_ic_timesteps
+
+ def record_batch_side_effect(
+ data: BatchData,
+ ):
+ nonlocal i
+ ret = [{"step": j} for j in range(i, i + data.time.shape[1])]
+ i += data.time.shape[1]
+ return ret
+
+ mock_aggregator.record_initial_condition = unittest.mock.MagicMock(
+ return_value=[{"step": j} for j in range(n_ic_timesteps)]
+ )
+
+ def get_summary_logs_side_effect():
+ # we expect this gets called _outside_ of run_inference
+ raise ValueError("should not be called")
+
+ mock_aggregator.get_summary_logs = unittest.mock.MagicMock(
+ side_effect=get_summary_logs_side_effect
+ )
+ mock_aggregator.record_batch = unittest.mock.MagicMock(
+ side_effect=record_batch_side_effect
+ )
+ return mock_aggregator
+
+
+def get_mock_writer() -> unittest.mock.MagicMock:
+ mock_writer = unittest.mock.MagicMock()
+ return mock_writer
+
+
+@pytest.mark.parametrize(
+ "n_ic_timesteps, n_forward_steps, n_iterations",
+ [
+ pytest.param(1, 2, 5, id="n_ic_timesteps=1"),
+ pytest.param(2, 2, 5, id="n_ic_timesteps=2"),
+ ],
+)
+def test_run_inference_simple(
+ n_ic_timesteps: int, n_forward_steps: int, n_iterations: int
+):
+ mock_derive_func = unittest.mock.MagicMock(
+ side_effect=lambda batch_data, forcing_data: batch_data
+ )
+ stepper = PlusOneStepper(
+ n_ic_timesteps=n_ic_timesteps, derive_func=mock_derive_func
+ )
+ initial_condition = get_batch_data(
+ 0,
+ n_timesteps=n_ic_timesteps,
+ ).get_start(prognostic_names=["var"], n_ic_timesteps=n_ic_timesteps)
+ loader = [
+ get_batch_data(
+ i,
+ n_ic_timesteps + n_forward_steps,
+ )
+ for i in range(0, n_iterations * n_forward_steps, n_forward_steps)
+ ]
+ mock_writer = get_mock_writer()
+ mock_aggregator = get_mock_aggregator(n_ic_timesteps)
+
+ with GlobalTimer():
+ with mock_wandb() as wandb:
+ wandb.configure(log_to_wandb=True)
+ record_logs = unittest.mock.MagicMock(
+ side_effect=get_record_to_wandb("inference")
+ ) # this init must be within mock_wandb context
+ run_inference(
+ predict=stepper.predict,
+ data=SimpleInferenceData(initial_condition, loader),
+ writer=mock_writer,
+ aggregator=mock_aggregator,
+ record_logs=record_logs,
+ )
+ wandb_logs = wandb.get_logs()
+ timer = GlobalTimer.get_instance()
+ times = timer.get_durations()
+ assert times["wandb_logging"] > 0
+ assert times["data_writer"] > 0
+ assert times["aggregator"] > 0
+ assert mock_writer.write.call_count == 2
+ assert mock_aggregator.record_initial_condition.call_count == 1
+ assert mock_writer.append_batch.call_count == n_iterations
+ assert mock_aggregator.record_batch.call_count == n_iterations
+ assert len(wandb_logs) == n_ic_timesteps + n_iterations * n_forward_steps
+ assert wandb_logs == [
+ {"inference/step": i}
+ for i in range(n_ic_timesteps + n_iterations * n_forward_steps)
+ ]
+ assert (
+ record_logs.call_count == n_iterations + 1
+ ) # +1 for the initial condition
diff --git a/fme/fme/core/generics/test_trainer.py b/fme/fme/core/generics/test_trainer.py
new file mode 100644
index 0000000..71d6635
--- /dev/null
+++ b/fme/fme/core/generics/test_trainer.py
@@ -0,0 +1,638 @@
+import contextlib
+import dataclasses
+import itertools
+import os
+import unittest.mock
+from typing import Any, Dict, Optional, Tuple, Type, TypeVar, cast
+
+import numpy as np
+import pytest
+import torch
+
+from fme.ace.data_loading.batch_data import DataLoader
+from fme.ace.stepper import TrainOutputABC, TrainStepperABC
+from fme.core.ema import EMATracker
+from fme.core.generics.aggregator import (
+ AggregatorABC,
+ InferenceAggregatorABC,
+ InferenceLog,
+ InferenceLogs,
+)
+from fme.core.generics.data import GriddedDataABC, InferenceDataABC
+from fme.core.generics.optimization import OptimizationABC
+from fme.core.generics.trainer import (
+ AggregatorBuilderABC,
+ CheckpointPaths,
+ TrainConfigProtocol,
+ Trainer,
+)
+from fme.core.optimization import Optimization
+from fme.core.scheduler import SchedulerConfig
+from fme.core.typing_ import Slice, TensorDict, TensorMapping
+
+
+class PSType:
+ pass
+
+
+class BDType:
+ pass
+
+
+class FDType:
+ pass
+
+
+class SDType:
+ pass
+
+
+class TrainOutput(TrainOutputABC):
+ def get_metrics(self) -> Dict[str, torch.Tensor]:
+ return {}
+
+
+class TrainData(GriddedDataABC[BDType]):
+ @property
+ def batch_size(self) -> int:
+ return 1
+
+ @property
+ def loader(self) -> DataLoader[BDType]:
+ return [BDType() for _ in range(self.n_batches)]
+
+ @property
+ def n_samples(self) -> int:
+ return 3
+
+ @property
+ def n_batches(self) -> int:
+ return 5
+
+ @property
+ def n_forward_steps(self) -> int:
+ return 1
+
+ def __init__(self):
+ self._set_epoch = unittest.mock.MagicMock()
+ self._log_info = unittest.mock.MagicMock()
+
+ def set_epoch(self, epoch: int) -> None:
+ self._set_epoch(epoch)
+
+ def log_info(self, name: str) -> None:
+ self._log_info(name)
+
+ @property
+ def set_epoch_mock(self) -> unittest.mock.Mock:
+ return self._set_epoch
+
+ @property
+ def log_info_mock(self) -> unittest.mock.Mock:
+ return self._log_info
+
+
+class InferenceData(InferenceDataABC[PSType, FDType]):
+ def __init__(self, n_time_windows: int = 1):
+ self.n_time_windows = n_time_windows
+
+ @property
+ def initial_condition(self) -> PSType:
+ return PSType()
+
+ @property
+ def loader(self) -> DataLoader[FDType]:
+ return [FDType() for _ in range(self.n_time_windows)]
+
+ @property
+ def n_window_forward_steps(self) -> int:
+ return 1
+
+
+class TrainStepper(TrainStepperABC[PSType, BDType, FDType, SDType, TrainOutput]):
+ SelfType = TypeVar("SelfType", bound="TrainStepper")
+
+ def __init__(
+ self,
+ state: Optional[Dict[str, Any]] = None,
+ ):
+ self._modules = torch.nn.ModuleList([torch.nn.Linear(1, 1, bias=False)])
+ self._modules[0].weight.data.fill_(0.0)
+ if state is not None:
+ self._state = state
+ else:
+ self._state = {}
+ self.loaded_state: Optional[Dict[str, Any]] = None
+
+ def get_state(self) -> Dict[str, Any]:
+ return {**self._state, "modules": self._modules.state_dict()}
+
+ def load_state(self, state: Dict[str, Any]) -> None:
+ self._state = state
+ self.loaded_state = state
+ self._modules.load_state_dict(state["modules"])
+
+ @classmethod
+ def from_state(cls: Type[SelfType], state: Dict[str, Any]) -> SelfType:
+ ret = cls()
+ ret.load_state(state)
+ return ret
+
+ @property
+ def modules(self) -> torch.nn.ModuleList:
+ return self._modules
+
+ @property
+ def n_ic_timesteps(self) -> int:
+ return 1
+
+ def normalize(self, data: TensorMapping) -> TensorDict:
+ return dict(data)
+
+ def predict_paired(
+ self,
+ initial_condition: PSType,
+ forcing: FDType,
+ compute_derived_variables: bool = False,
+ ) -> Tuple[SDType, PSType]:
+ return SDType(), PSType()
+
+ def train_on_batch(
+ self,
+ batch: BDType,
+ optimization: OptimizationABC,
+ compute_derived_variables: bool = False,
+ ) -> TrainOutput:
+ optimization.step_weights(torch.tensor(float("inf")))
+ return TrainOutput()
+
+
+@dataclasses.dataclass
+class Config:
+ experiment_dir: str = "test_experiment_dir"
+ checkpoint_dir: str = "test_checkpoint_dir"
+ max_epochs: int = 2
+ save_checkpoint: bool = True
+ validate_using_ema: bool = True
+ log_train_every_n_batches: int = 1
+ inference_n_forward_steps: int = 1
+ checkpoint_save_epochs: Optional[Slice] = None
+ ema_checkpoint_save_epochs: Optional[Slice] = None
+ segment_epochs: Optional[int] = None
+ clean_wandb = unittest.mock.MagicMock()
+
+ def __post_init__(self):
+ self.get_inference_epochs = unittest.mock.MagicMock(
+ return_value=[i for i in range(self.max_epochs)]
+ )
+
+
+_: TrainConfigProtocol = Config()
+
+
+class TrainAggregator(AggregatorABC[TrainOutput]):
+ def __init__(self, train_loss: float):
+ self.train_loss = train_loss
+
+ def record_batch(self, batch: TrainOutput) -> None:
+ pass
+
+ def get_logs(self, label: str) -> Dict[str, Any]:
+ return {f"{label}/mean/loss": self.train_loss}
+
+
+class ValidationAggregator(AggregatorABC[TrainOutput]):
+ def __init__(self, validation_loss: float):
+ self.validation_loss = validation_loss
+
+ def record_batch(self, batch: TrainOutput) -> None:
+ pass
+
+ def get_logs(self, label: str) -> Dict[str, Any]:
+ return {f"{label}/mean/loss": self.validation_loss}
+
+
+class InferenceAggregator(InferenceAggregatorABC[PSType, SDType]):
+ def __init__(self, inference_loss: float):
+ self.inference_loss = inference_loss
+
+ def record_batch(self, data: SDType) -> InferenceLogs:
+ return [{}]
+
+ def record_initial_condition(self, initial_condition: PSType) -> InferenceLogs:
+ return [{}]
+
+ def get_summary_logs(self) -> InferenceLog:
+ return {"time_mean_norm/rmse/channel_mean": self.inference_loss}
+
+
+class AggregatorBuilder(AggregatorBuilderABC[PSType, TrainOutput, SDType]):
+ def __init__(
+ self,
+ train_losses: np.ndarray,
+ validation_losses: np.ndarray,
+ inference_losses: np.ndarray,
+ ):
+ self.train_losses = train_losses
+ self.validation_losses = validation_losses
+ self.inference_losses = inference_losses
+ self._train_calls = 0
+ self._validation_calls = 0
+ self._inference_calls = 0
+
+ def get_train_aggregator(self) -> AggregatorABC[TrainOutput]:
+ ret = TrainAggregator(self.train_losses[self._train_calls])
+ self._train_calls += 1
+ return ret
+
+ def get_validation_aggregator(self) -> AggregatorABC[TrainOutput]:
+ ret = ValidationAggregator(self.validation_losses[self._validation_calls])
+ self._validation_calls += 1
+ return ret
+
+ def get_inference_aggregator(self) -> InferenceAggregatorABC[PSType, SDType]:
+ ret = InferenceAggregator(self.inference_losses[self._inference_calls])
+ self._inference_calls += 1
+ return ret
+
+
+def get_trainer(
+ tmp_path: str,
+ checkpoint_save_epochs: Optional[Slice] = None,
+ segment_epochs: Optional[int] = None,
+ max_epochs: int = 8,
+ checkpoint_dir: Optional[str] = None,
+ stepper_state: Optional[Dict[str, Any]] = None,
+ train_losses: Optional[np.ndarray] = None,
+ validation_losses: Optional[np.ndarray] = None,
+ inference_losses: Optional[np.ndarray] = None,
+ stepper_module_values: Optional[np.ndarray] = None,
+ ema_decay: float = 0.9999,
+ validate_using_ema: bool = True,
+) -> Tuple[TrainConfigProtocol, Trainer]:
+ if checkpoint_dir is None:
+ checkpoint_dir = os.path.join(tmp_path, "checkpoints")
+ if train_losses is None:
+ train_losses = np.zeros(max_epochs)
+ if validation_losses is None:
+ validation_losses = np.zeros(max_epochs)
+ if inference_losses is None:
+ inference_losses = np.zeros(max_epochs)
+ if stepper_module_values is None:
+ stepper_module_values = np.zeros(max_epochs)
+ train_data = TrainData()
+ validation_data = TrainData()
+ inference_data = InferenceData()
+ stepper = TrainStepper(state=stepper_state)
+
+ def build_optimization(modules: torch.nn.ModuleList) -> Optimization:
+ if len(modules) != 1:
+ raise ValueError("Expected 1 linear module with 1 weight")
+ if not isinstance(modules[0], torch.nn.Linear):
+ raise ValueError("Expected a linear module")
+ module = modules[0]
+ if module.weight.numel() != 1:
+ raise ValueError("Expected a linear module with 1 weight")
+ i = 0
+
+ opt = Optimization(
+ parameters=itertools.chain(*[module.parameters() for module in modules]),
+ optimizer_type="Adam",
+ lr=0.01,
+ max_epochs=max_epochs,
+ scheduler=SchedulerConfig(),
+ enable_automatic_mixed_precision=False,
+ kwargs={},
+ )
+ original_step_scheduler = opt.step_scheduler
+
+ def step_scheduler_side_effect(*args, **kwargs):
+ original_step_scheduler(*args, **kwargs)
+ nonlocal i
+ i += 1
+
+ opt.step_scheduler = unittest.mock.MagicMock( # type: ignore
+ side_effect=step_scheduler_side_effect
+ )
+
+ def step_weights_side_effect(*args, **kwargs):
+ if stepper_module_values is None:
+ raise ValueError("stepper_module_values is None")
+ module.weight.data.fill_(stepper_module_values[i])
+
+ opt.step_weights = unittest.mock.MagicMock(side_effect=step_weights_side_effect) # type: ignore
+ return opt
+
+ def build_ema(modules: torch.nn.ModuleList) -> EMATracker:
+ return EMATracker(modules, decay=ema_decay)
+
+ config: TrainConfigProtocol = Config(
+ experiment_dir=tmp_path,
+ checkpoint_dir=checkpoint_dir,
+ checkpoint_save_epochs=checkpoint_save_epochs,
+ segment_epochs=segment_epochs,
+ max_epochs=max_epochs,
+ validate_using_ema=validate_using_ema,
+ )
+ aggregator_builder = AggregatorBuilder(
+ train_losses=train_losses,
+ validation_losses=validation_losses,
+ inference_losses=inference_losses,
+ )
+ callback = unittest.mock.MagicMock()
+ return config, Trainer(
+ train_data=train_data,
+ validation_data=validation_data,
+ inference_data=inference_data,
+ stepper=stepper,
+ build_optimization=build_optimization,
+ build_ema=build_ema,
+ config=config,
+ aggregator_builder=aggregator_builder,
+ end_of_batch_callback=callback,
+ do_gc_collect=False, # for much faster tests
+ )
+
+
+@pytest.mark.parametrize(
+ "checkpoint_save_epochs",
+ [None, Slice(start=2, stop=3), Slice(start=1, step=2)],
+)
+def test_trainer(tmp_path: str, checkpoint_save_epochs: Optional[Slice]):
+ config, trainer = get_trainer(tmp_path, checkpoint_save_epochs, max_epochs=4)
+ trainer.train()
+ assert os.path.exists(config.experiment_dir)
+ assert os.path.exists(config.checkpoint_dir)
+ paths = CheckpointPaths(config.checkpoint_dir)
+ assert os.path.exists(paths.latest_checkpoint_path)
+ assert os.path.exists(paths.best_checkpoint_path)
+ assert os.path.exists(paths.best_inference_checkpoint_path)
+ assert os.path.exists(paths.ema_checkpoint_path)
+ save_epochs = list(range(config.max_epochs))
+ if checkpoint_save_epochs is not None:
+ save_epochs = save_epochs[checkpoint_save_epochs.slice]
+ else:
+ save_epochs = []
+ for i in range(config.max_epochs):
+ if i in save_epochs:
+ assert os.path.exists(paths.epoch_checkpoint_path(i))
+ else:
+ assert not os.path.exists(paths.epoch_checkpoint_path(i))
+ assert not os.path.exists(paths.ema_epoch_checkpoint_path(i))
+ train_data = cast(TrainData, trainer.train_data)
+ valid_data = cast(TrainData, trainer.valid_data)
+ assert train_data.set_epoch_mock.mock_calls == [
+ unittest.mock.call(i) for i in range(config.max_epochs)
+ ]
+ assert valid_data.set_epoch_mock.mock_calls == [] # no shuffling
+ assert train_data.log_info_mock.called
+ assert valid_data.log_info_mock.called
+
+
+@pytest.mark.parametrize("segment_epochs", [1, 2, 3])
+def test_segmented_trainer_runs_correct_epochs(tmp_path: str, segment_epochs: int):
+ max_epochs = 4
+ total_segments = max_epochs // segment_epochs
+ for i in range(total_segments):
+ config, trainer = get_trainer(
+ tmp_path,
+ checkpoint_dir=os.path.join(tmp_path, "checkpoint_dir"), # same dir for all
+ # for speed, don't save per-epoch checkpoints for this test
+ checkpoint_save_epochs=Slice(start=0, stop=0),
+ segment_epochs=segment_epochs,
+ max_epochs=max_epochs,
+ )
+ trainer.train()
+ paths = CheckpointPaths(config.checkpoint_dir)
+ assert os.path.exists(paths.latest_checkpoint_path)
+ train_data = cast(TrainData, trainer.train_data)
+ assert train_data.set_epoch_mock.mock_calls == [
+ unittest.mock.call(i)
+ for i in range(
+ i * segment_epochs,
+ min((i + 1) * segment_epochs, config.max_epochs),
+ )
+ ]
+
+
+class TrainingInterrupted(Exception):
+ pass
+
+
+@contextlib.contextmanager
+def fail_after_calls_patch(object, method: str, call_count: int):
+ total_calls = 0
+ original_method = getattr(object, method)
+
+ def wrapper(*args, **kwargs):
+ nonlocal total_calls
+ total_calls += 1
+ if total_calls >= call_count:
+ raise TrainingInterrupted()
+ return original_method(*args, **kwargs)
+
+ with unittest.mock.patch.object(object, method) as mock:
+ mock.side_effect = wrapper
+ try:
+ yield mock
+ except TrainingInterrupted:
+ pass
+
+
+@pytest.mark.parametrize(
+ "interrupt_method",
+ ["train_one_epoch", "validate_one_epoch", "inference_one_epoch"],
+)
+def test_resume_after_interrupted_training(tmp_path: str, interrupt_method: str):
+ max_epochs = 4
+ calls_before_interrupt = 2
+ stepper_state = {"foo": "bar"}
+ config, trainer = get_trainer(
+ tmp_path,
+ stepper_state=stepper_state,
+ checkpoint_save_epochs=Slice(start=0, stop=0),
+ max_epochs=max_epochs,
+ )
+ with fail_after_calls_patch(trainer, interrupt_method, calls_before_interrupt):
+ trainer.train()
+ train_data = cast(TrainData, trainer.train_data)
+ assert train_data.set_epoch_mock.mock_calls == [
+ unittest.mock.call(i) for i in range(calls_before_interrupt)
+ ]
+ paths = CheckpointPaths(config.checkpoint_dir)
+ assert os.path.exists(paths.latest_checkpoint_path)
+ _, trainer = get_trainer(
+ tmp_path,
+ checkpoint_save_epochs=Slice(start=0, stop=0),
+ max_epochs=max_epochs,
+ stepper_state=stepper_state,
+ )
+ trainer.train()
+ train_data = cast(TrainData, trainer.train_data)
+ assert train_data.set_epoch_mock.mock_calls == [
+ unittest.mock.call(i) for i in range(calls_before_interrupt - 1, max_epochs)
+ ]
+ stepper = cast(TrainStepper, trainer.stepper)
+ assert stepper.loaded_state is not None
+ assert stepper.loaded_state["foo"] == "bar"
+ assert "modules" in stepper.loaded_state
+ assert len(stepper.loaded_state) == 2
+
+
+@pytest.mark.parametrize("ema_decay", [0.05, 0.99])
+@pytest.mark.parametrize("validate_using_ema", [True, False])
+def test_saves_correct_ema_checkpoints(
+ tmp_path: str, ema_decay: float, validate_using_ema: bool
+):
+ config, trainer = get_trainer(
+ tmp_path,
+ checkpoint_dir=os.path.join(tmp_path, "checkpoint_dir"), # same dir for all
+ ema_decay=ema_decay,
+ validate_using_ema=validate_using_ema,
+ )
+ valid_loss = 0.1
+ inference_error = 0.2
+ trainer.stepper.modules[0].weight.data.fill_(1.0)
+ trainer._ema(model=trainer.stepper.modules)
+ trainer.save_all_checkpoints(valid_loss=valid_loss, inference_error=inference_error)
+ paths = CheckpointPaths(config.checkpoint_dir)
+ assert os.path.exists(paths.ema_checkpoint_path)
+ ema_checkpoint = torch.load(paths.ema_checkpoint_path)
+ ema_weight = 1.0 - min(ema_decay, 2.0 / 11.0)
+ np.testing.assert_allclose(
+ ema_checkpoint["stepper"]["modules"]["0.weight"].cpu().numpy(),
+ ema_weight,
+ atol=1e-7,
+ )
+ assert ema_checkpoint["best_validation_loss"] == valid_loss
+ assert ema_checkpoint["best_inference_error"] == inference_error
+ assert os.path.exists(paths.latest_checkpoint_path)
+ latest_checkpoint = torch.load(paths.latest_checkpoint_path)
+ np.testing.assert_allclose(
+ latest_checkpoint["stepper"]["modules"]["0.weight"].cpu().numpy(),
+ 1.0,
+ atol=1e-7,
+ )
+ assert latest_checkpoint["best_validation_loss"] == valid_loss
+ assert latest_checkpoint["best_inference_error"] == inference_error
+ if validate_using_ema:
+ best_weight = ema_weight
+ else:
+ best_weight = 1.0
+ assert os.path.exists(paths.best_checkpoint_path)
+ best_checkpoint = torch.load(paths.best_checkpoint_path)
+ assert best_checkpoint["best_validation_loss"] == valid_loss
+ assert best_checkpoint["best_inference_error"] == inference_error
+ np.testing.assert_allclose(
+ best_checkpoint["stepper"]["modules"]["0.weight"].cpu().numpy(),
+ best_weight,
+ atol=1e-7,
+ )
+ best_inference_checkpoint = torch.load(paths.best_inference_checkpoint_path)
+ assert best_inference_checkpoint["best_validation_loss"] == valid_loss
+ assert best_inference_checkpoint["best_inference_error"] == inference_error
+ np.testing.assert_allclose(
+ best_inference_checkpoint["stepper"]["modules"]["0.weight"].cpu().numpy(),
+ best_weight,
+ atol=1e-7,
+ )
+
+
+@pytest.mark.parametrize(
+ "segment_epochs, best_val_epoch, best_inference_epoch",
+ [
+ (None, 1, 1),
+ (None, 3, 5),
+ (2, 3, 5),
+ (2, 5, 3),
+ (2, 4, 6),
+ (2, 6, 4),
+ ],
+)
+def test_saves_correct_non_ema_epoch_checkpoints(
+ tmp_path: str,
+ segment_epochs: Optional[int],
+ best_val_epoch: int,
+ best_inference_epoch: int,
+):
+ max_epochs = 10
+ if segment_epochs is None:
+ total_segments = 1
+ segment_epochs_value = max_epochs
+ else:
+ total_segments = max_epochs // segment_epochs
+ segment_epochs_value = segment_epochs
+ train_losses = np.random.rand(max_epochs) + 0.01
+ val_losses = np.random.rand(max_epochs) + 0.01
+ inference_losses = np.random.rand(max_epochs) + 0.01
+ val_losses[best_val_epoch - 1] = 0.0
+ inference_losses[best_inference_epoch - 1] = 0.0
+ module_values = np.random.rand(max_epochs)
+ for i in range(total_segments):
+ config, trainer = get_trainer(
+ tmp_path,
+ checkpoint_dir=os.path.join(tmp_path, "checkpoint_dir"), # same dir for all
+ # for speed, don't save per-epoch checkpoints for this test
+ checkpoint_save_epochs=Slice(start=0, stop=0),
+ segment_epochs=segment_epochs,
+ max_epochs=max_epochs,
+ train_losses=train_losses[
+ i * segment_epochs_value : (i + 1) * segment_epochs_value
+ ],
+ validation_losses=val_losses[
+ i * segment_epochs_value : (i + 1) * segment_epochs_value
+ ],
+ inference_losses=inference_losses[
+ i * segment_epochs_value : (i + 1) * segment_epochs_value
+ ],
+ stepper_module_values=module_values[
+ i * segment_epochs_value : (i + 1) * segment_epochs_value
+ ],
+ validate_using_ema=False,
+ )
+ trainer.train()
+ paths = CheckpointPaths(config.checkpoint_dir)
+ assert os.path.exists(paths.latest_checkpoint_path)
+ train_data = cast(TrainData, trainer.train_data)
+ assert train_data.set_epoch_mock.mock_calls == [
+ unittest.mock.call(i)
+ for i in range(
+ i * segment_epochs_value,
+ min((i + 1) * segment_epochs_value, config.max_epochs),
+ )
+ ]
+ latest_checkpoint = torch.load(paths.latest_checkpoint_path)
+ assert latest_checkpoint["epoch"] == min(
+ max_epochs, (i + 1) * segment_epochs_value
+ )
+ np.testing.assert_allclose(
+ latest_checkpoint["stepper"]["modules"]["0.weight"].cpu().numpy(),
+ module_values[min((i + 1) * segment_epochs_value - 1, max_epochs - 1)],
+ )
+ paths = CheckpointPaths(config.checkpoint_dir)
+ assert os.path.exists(paths.latest_checkpoint_path)
+ assert os.path.exists(paths.best_checkpoint_path)
+ assert os.path.exists(paths.best_inference_checkpoint_path)
+ assert os.path.exists(paths.ema_checkpoint_path)
+ best_checkpoint = torch.load(paths.best_checkpoint_path)
+ assert best_checkpoint["epoch"] == best_val_epoch
+ assert best_checkpoint["best_validation_loss"] == 0.0
+ assert best_checkpoint["best_inference_error"] == np.min(
+ inference_losses[:best_val_epoch]
+ )
+ np.testing.assert_allclose(
+ best_checkpoint["stepper"]["modules"]["0.weight"].cpu().numpy(),
+ module_values[best_val_epoch - 1],
+ )
+ best_inference_checkpoint = torch.load(paths.best_inference_checkpoint_path)
+ assert best_inference_checkpoint["epoch"] == best_inference_epoch
+ assert best_inference_checkpoint["best_validation_loss"] == np.min(
+ val_losses[:best_inference_epoch]
+ )
+ assert best_inference_checkpoint["best_inference_error"] == 0.0
+ latest_checkpoint = torch.load(paths.latest_checkpoint_path)
+ assert latest_checkpoint["epoch"] == max_epochs
+ np.testing.assert_allclose(
+ latest_checkpoint["stepper"]["modules"]["0.weight"].cpu().numpy(),
+ module_values[-1],
+ )
diff --git a/fme/fme/core/generics/train_stepper.py b/fme/fme/core/generics/train_stepper.py
new file mode 100644
index 0000000..e9db6c9
--- /dev/null
+++ b/fme/fme/core/generics/train_stepper.py
@@ -0,0 +1,63 @@
+import abc
+from typing import Any, Dict, Generic, Type, TypeVar
+
+from torch import nn
+
+from fme.core.generics.inference import PredictFunction
+from fme.core.generics.optimization import OptimizationABC
+from fme.core.typing_ import TensorDict
+
+TO = TypeVar("TO", bound="TrainOutputABC") # train output
+
+
+class TrainOutputABC(abc.ABC):
+ @abc.abstractmethod
+ def get_metrics(self) -> TensorDict:
+ pass
+
+
+PS = TypeVar("PS") # prognostic state
+BD = TypeVar("BD") # batch data
+FD = TypeVar("FD") # forcing data
+SD = TypeVar("SD") # stepped data
+
+
+class TrainStepperABC(abc.ABC, Generic[PS, BD, FD, SD, TO]):
+ SelfType = TypeVar("SelfType", bound="TrainStepperABC")
+
+ @abc.abstractmethod
+ def train_on_batch(
+ self,
+ data: BD,
+ optimization: OptimizationABC,
+ compute_derived_variables: bool = False,
+ ) -> TO:
+ pass
+
+ @property
+ @abc.abstractmethod
+ def modules(self) -> nn.ModuleList:
+ pass
+
+ @abc.abstractmethod
+ def get_state(self) -> Dict[str, Any]:
+ pass
+
+ @abc.abstractmethod
+ def load_state(self, state: Dict[str, Any]) -> None:
+ pass
+
+ @classmethod
+ @abc.abstractmethod
+ def from_state(cls: Type[SelfType], state: Dict[str, Any]) -> SelfType:
+ pass
+
+ @property
+ @abc.abstractmethod
+ def n_ic_timesteps(self) -> int:
+ pass
+
+ @property
+ @abc.abstractmethod
+ def predict_paired(self) -> PredictFunction[PS, FD, SD]:
+ pass
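+
+
+# Illustrative sketch only (the names below are hypothetical, not defined in fme):
+# a concrete stepper binds the generic parameters and fills in the abstract API.
+#
+#   class MyStepper(TrainStepperABC[MyState, MyBatch, MyForcing, MyStepped, MyOutput]):
+#       def train_on_batch(self, data, optimization, compute_derived_variables=False):
+#           ...  # forward pass + optimization step, returning a TrainOutputABC
+#       @property
+#       def modules(self) -> nn.ModuleList: ...
+#       def get_state(self): ...
+#       def load_state(self, state): ...
+#       @classmethod
+#       def from_state(cls, state): ...
+#       @property
+#       def n_ic_timesteps(self) -> int: ...
+#       @property
+#       def predict_paired(self): ...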
diff --git a/fme/fme/core/generics/trainer.py b/fme/fme/core/generics/trainer.py
new file mode 100644
index 0000000..8e88a13
--- /dev/null
+++ b/fme/fme/core/generics/trainer.py
@@ -0,0 +1,538 @@
+# This module is derived from the train.py module in the following repository:
+# https://github.com/NVlabs/FourCastNet. The corresponding license is
+# provided below.
+
+# BSD 3-Clause License
+#
+# Copyright (c) 2022, FourCastNet authors
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# 3. Neither the name of the copyright holder nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+#
+# The code was authored by the following people:
+#
+# Jaideep Pathak - NVIDIA Corporation
+# Shashank Subramanian - NERSC, Lawrence Berkeley National Laboratory
+# Peter Harrington - NERSC, Lawrence Berkeley National Laboratory
+# Sanjeev Raja - NERSC, Lawrence Berkeley National Laboratory
+# Ashesh Chattopadhyay - Rice University
+# Morteza Mardani - NVIDIA Corporation
+# Thorsten Kurth - NVIDIA Corporation
+# David Hall - NVIDIA Corporation
+# Zongyi Li - California Institute of Technology, NVIDIA Corporation
+# Kamyar Azizzadenesheli - Purdue University
+# Pedram Hassanzadeh - Rice University
+# Karthik Kashinath - NVIDIA Corporation
+# Animashree Anandkumar - California Institute of Technology, NVIDIA Corporation
+
+import abc
+import contextlib
+import gc
+import logging
+import os
+import time
+import uuid
+from typing import (
+ Any,
+ Callable,
+ ClassVar,
+ Dict,
+ Generic,
+ List,
+ Optional,
+ Protocol,
+ TypeVar,
+)
+
+import torch
+
+import fme
+from fme.core.distributed import Distributed
+from fme.core.ema import EMATracker
+from fme.core.generics.aggregator import AggregatorABC, InferenceAggregatorABC
+from fme.core.generics.data import GriddedDataABC, InferenceDataABC
+from fme.core.generics.inference import run_inference
+from fme.core.generics.train_stepper import TrainOutputABC, TrainStepperABC
+from fme.core.optimization import NullOptimization, Optimization
+from fme.core.timing import GlobalTimer
+from fme.core.typing_ import Slice
+from fme.core.wandb import WandB
+
+
+class EndOfBatchCallback(Protocol):
+ def __call__(self) -> None: ...
+
+
+class TrainConfigProtocol(Protocol):
+ __dataclass_fields__: ClassVar[Dict[str, Any]]
+
+ @property
+ def experiment_dir(self) -> str: ...
+
+ @property
+ def checkpoint_dir(self) -> str: ...
+
+ @property
+ def max_epochs(self) -> int: ...
+
+ @property
+ def save_checkpoint(self) -> bool: ...
+
+ @property
+ def validate_using_ema(self) -> bool: ...
+
+ @property
+ def log_train_every_n_batches(self) -> int: ...
+
+ @property
+ def segment_epochs(self) -> Optional[int]: ...
+
+ @property
+ def checkpoint_save_epochs(self) -> Optional[Slice]: ...
+
+ @property
+ def ema_checkpoint_save_epochs(self) -> Optional[Slice]: ...
+
+ def clean_wandb(self, experiment_dir: str) -> None: ...
+
+ def get_inference_epochs(self) -> List[int]: ...
+
+
+PS = TypeVar("PS", contravariant=True) # prognostic state
+TO = TypeVar("TO", bound="TrainOutputABC") # train output
+BD = TypeVar("BD") # batch data for training
+FD = TypeVar("FD") # forcing data for inference
+SD = TypeVar("SD") # stepped data from inference
+
+
+class AggregatorBuilderABC(abc.ABC, Generic[PS, TO, SD]):
+ @abc.abstractmethod
+ def get_train_aggregator(self) -> AggregatorABC[TO]:
+ pass
+
+ @abc.abstractmethod
+ def get_validation_aggregator(self) -> AggregatorABC[TO]:
+ pass
+
+ @abc.abstractmethod
+ def get_inference_aggregator(self) -> InferenceAggregatorABC[PS, SD]:
+ pass
+
+
+class CheckpointPaths:
+ def __init__(self, checkpoint_dir: str):
+ self.checkpoint_dir = checkpoint_dir
+
+ @property
+ def latest_checkpoint_path(self) -> str:
+ return os.path.join(self.checkpoint_dir, "ckpt.tar")
+
+ @property
+ def best_checkpoint_path(self) -> str:
+ return os.path.join(self.checkpoint_dir, "best_ckpt.tar")
+
+ @property
+ def best_inference_checkpoint_path(self) -> str:
+ return os.path.join(self.checkpoint_dir, "best_inference_ckpt.tar")
+
+ @property
+ def ema_checkpoint_path(self) -> str:
+ return os.path.join(self.checkpoint_dir, "ema_ckpt.tar")
+
+ def epoch_checkpoint_path(self, epoch: int) -> str:
+ return os.path.join(self.checkpoint_dir, f"ckpt_{epoch:04d}.tar")
+
+ def ema_epoch_checkpoint_path(self, epoch: int) -> str:
+ return os.path.join(self.checkpoint_dir, f"ema_ckpt_{epoch:04d}.tar")
+
+
+class Trainer:
+ def __init__(
+ self,
+ train_data: GriddedDataABC[BD],
+ validation_data: GriddedDataABC[BD],
+ inference_data: InferenceDataABC[PS, FD],
+ stepper: TrainStepperABC[PS, BD, FD, SD, TO],
+ build_optimization: Callable[[torch.nn.ModuleList], Optimization],
+ build_ema: Callable[[torch.nn.ModuleList], EMATracker],
+ config: TrainConfigProtocol,
+ aggregator_builder: AggregatorBuilderABC[PS, TO, SD],
+ end_of_batch_callback: EndOfBatchCallback = lambda: None,
+ do_gc_collect: bool = True,
+ ):
+ logging.info(f"Current device is {fme.get_device()}")
+ self.dist = Distributed.get_instance()
+ if self.dist.is_root():
+ if not os.path.isdir(config.experiment_dir):
+ os.makedirs(config.experiment_dir)
+ if not os.path.isdir(config.checkpoint_dir):
+ os.makedirs(config.checkpoint_dir)
+ self.config = config
+ self.paths = CheckpointPaths(config.checkpoint_dir)
+
+ logging.info("rank %d, begin data loader init" % self.dist.rank)
+ self.train_data = train_data
+ self.valid_data = validation_data
+ logging.info("rank %d, data loader initialized" % self.dist.rank)
+ for gridded_data, name in zip(
+ (self.train_data, self.valid_data), ("train", "valid")
+ ):
+ gridded_data.log_info(name)
+
+ self.num_batches_seen = 0
+ self._start_epoch = 0
+ self._model_epoch = self._start_epoch
+ self._best_validation_loss = torch.inf
+ self._best_inference_error = torch.inf
+
+ self.stepper = stepper
+ self.optimization = build_optimization(stepper.modules)
+ self._end_of_batch_ops = end_of_batch_callback
+ self._no_optimization = NullOptimization()
+ self._aggregator_builder = aggregator_builder
+
+ resuming = os.path.isfile(self.paths.latest_checkpoint_path)
+ if resuming:
+ logging.info(f"Resuming training from {self.paths.latest_checkpoint_path}")
+ self.restore_checkpoint(
+ self.paths.latest_checkpoint_path, self.paths.ema_checkpoint_path
+ )
+
+ wandb = WandB.get_instance()
+ wandb.watch(self.stepper.modules)
+
+ logging.info(
+ (
+ "Number of trainable model parameters: "
+ f"{count_parameters(self.stepper.modules)}"
+ )
+ )
+
+ self._inference_data = inference_data
+ self._ema = build_ema(stepper.modules)
+ self._do_gc_collect = do_gc_collect
+
+ def switch_off_grad(self, model: torch.nn.Module):
+ for param in model.parameters():
+ param.requires_grad = False
+
+ def train(self):
+ logging.info("Starting Training Loop...")
+
+ self._model_epoch = self._start_epoch
+ inference_epochs = self.config.get_inference_epochs()
+ if self.config.segment_epochs is None:
+ segment_max_epochs = self.config.max_epochs
+ else:
+ segment_max_epochs = min(
+ self._start_epoch + self.config.segment_epochs, self.config.max_epochs
+ )
+ # "epoch" describes the loop, self._model_epoch describes model weights
+ # needed so we can describe the loop even after weights are updated
+ for epoch in range(self._start_epoch, segment_max_epochs):
+ if self._do_gc_collect:
+ # garbage collect to avoid CUDA error in some contexts
+ # https://github.com/pytorch/pytorch/issues/67978#issuecomment-1661986812 # noqa: E501
+ gc.collect()
+ logging.info(f"Epoch: {epoch+1}")
+ self.train_data.set_epoch(epoch)
+
+ start_time = time.time()
+ logging.info(f"Starting training step on epoch {epoch + 1}")
+ train_logs = self.train_one_epoch()
+ train_end = time.time()
+ logging.info(f"Starting validation step on epoch {epoch + 1}")
+ valid_logs = self.validate_one_epoch()
+ valid_end = time.time()
+ if epoch in inference_epochs:
+ logging.info(f"Starting inference step on epoch {epoch + 1}")
+ inference_logs = self.inference_one_epoch()
+ inference_end: Optional[float] = time.time()
+ else:
+ inference_logs = {}
+ inference_end = None
+
+ train_loss = train_logs["train/mean/loss"]
+ valid_loss = valid_logs["val/mean/loss"]
+ inference_error = inference_logs.get(
+ "inference/time_mean_norm/rmse/channel_mean", None
+ )
+ # need to get the learning rate before stepping the scheduler
+ lr = self.optimization.learning_rate
+ self.optimization.step_scheduler(valid_loss)
+
+ if self.dist.is_root():
+ if self.config.save_checkpoint:
+ logging.info(f"Saving checkpoints for epoch {epoch + 1}")
+ self.save_all_checkpoints(valid_loss, inference_error)
+
+ time_elapsed = time.time() - start_time
+ logging.info(f"Time taken for epoch {epoch + 1} is {time_elapsed} sec")
+ logging.info(f"Train loss: {train_loss}. Valid loss: {valid_loss}")
+ if inference_error is not None:
+ logging.info(f"Inference error: {inference_error}")
+
+ logging.info("Logging to wandb")
+ all_logs = {
+ **train_logs,
+ **valid_logs,
+ **inference_logs,
+ **{
+ "lr": lr,
+ "epoch": epoch,
+ "epoch_train_seconds": train_end - start_time,
+ "epoch_validation_seconds": valid_end - train_end,
+ "epoch_total_seconds": time_elapsed,
+ },
+ }
+ if inference_end is not None:
+ all_logs["epoch_inference_seconds"] = inference_end - valid_end
+ wandb = WandB.get_instance()
+ wandb.log(all_logs, step=self.num_batches_seen)
+ if segment_max_epochs == self.config.max_epochs:
+ self.config.clean_wandb(experiment_dir=self.config.experiment_dir)
+
+ def train_one_epoch(self):
+ """Train for one epoch and return logs from TrainAggregator."""
+ wandb = WandB.get_instance()
+ aggregator = self._aggregator_builder.get_train_aggregator()
+ n_samples_seen_since_logging = 0
+ if self.num_batches_seen == 0:
+ # Before training, log the loss on the first batch.
+ with torch.no_grad(), GlobalTimer():
+ batch = next(iter(self.train_data.loader))
+ stepped = self.stepper.train_on_batch(
+ batch,
+ optimization=self._no_optimization,
+ )
+
+ if self.config.log_train_every_n_batches > 0:
+ with torch.no_grad():
+ metrics = {
+ f"batch_{name}": self.dist.reduce_mean(metric)
+ for name, metric in sorted(stepped.get_metrics().items())
+ }
+ wandb.log(metrics, step=self.num_batches_seen)
+ current_time = time.time()
+ for batch in self.train_data.loader:
+ with GlobalTimer():
+ stepped = self.stepper.train_on_batch(batch, self.optimization)
+ aggregator.record_batch(stepped)
+ self._end_of_batch_ops()
+ self._ema(model=self.stepper.modules)
+ self.num_batches_seen += 1
+ n_samples_seen_since_logging += self.train_data.batch_size
+ if (
+ self.config.log_train_every_n_batches > 0
+ and self.num_batches_seen % self.config.log_train_every_n_batches == 0
+ ):
+ with torch.no_grad():
+ metrics = {
+ f"batch_{name}": self.dist.reduce_mean(metric)
+ for name, metric in sorted(stepped.get_metrics().items())
+ }
+ duration = time.time() - current_time
+ current_time = time.time()
+ samples_per_second = n_samples_seen_since_logging / duration
+ metrics["training_samples_per_second_on_rank_0"] = samples_per_second
+ wandb.log(metrics, step=self.num_batches_seen)
+ n_samples_seen_since_logging = 0
+ self._model_epoch += 1
+
+ return aggregator.get_logs(label="train")
+
+ @contextlib.contextmanager
+ def _validation_context(self):
+ """
+ The context for running validation.
+
+ In this context, the stepper uses the EMA model if
+ `self.config.validate_using_ema` is True.
+ """
+ if self.config.validate_using_ema:
+ with self._ema_context():
+ yield
+ else:
+ yield
+
+ @contextlib.contextmanager
+ def _ema_context(self):
+ """
+ A context where the stepper uses the EMA model.
+ """
+ self._ema.store(parameters=self.stepper.modules.parameters())
+ self._ema.copy_to(model=self.stepper.modules)
+ try:
+ yield
+ finally:
+ self._ema.restore(parameters=self.stepper.modules.parameters())
+
+ def validate_one_epoch(self):
+ aggregator = self._aggregator_builder.get_validation_aggregator()
+ with torch.no_grad(), self._validation_context(), GlobalTimer():
+ for batch in self.valid_data.loader:
+ stepped = self.stepper.train_on_batch(
+ batch,
+ optimization=NullOptimization(),
+ compute_derived_variables=True,
+ )
+ aggregator.record_batch(
+ batch=stepped,
+ )
+ return aggregator.get_logs(label="val")
+
+ def inference_one_epoch(self):
+ aggregator = self._aggregator_builder.get_inference_aggregator()
+ with torch.no_grad(), self._validation_context(), GlobalTimer():
+ run_inference(
+ predict=self.stepper.predict_paired,
+ data=self._inference_data,
+ aggregator=aggregator,
+ )
+ logs = aggregator.get_summary_logs()
+ return {f"inference/{k}": v for k, v in logs.items()}
+
+ def save_checkpoint(self, checkpoint_path):
+ # save to a temporary file in case we get pre-empted during save
+ temporary_location = os.path.join(
+ os.path.dirname(checkpoint_path), f".{uuid.uuid4()}.tmp"
+ )
+ try:
+ torch.save(
+ {
+ "num_batches_seen": self.num_batches_seen,
+ "epoch": self._model_epoch,
+ "best_validation_loss": self._best_validation_loss,
+ "best_inference_error": self._best_inference_error,
+ "stepper": self.stepper.get_state(),
+ "optimization": self.optimization.get_state(),
+ "ema": self._ema.get_state(),
+ },
+ temporary_location,
+ )
+ os.replace(temporary_location, checkpoint_path)
+ finally:
+ if os.path.exists(temporary_location):
+ os.remove(temporary_location)
+
+ def restore_checkpoint(self, checkpoint_path, ema_checkpoint_path):
+ _restore_checkpoint(self, checkpoint_path, ema_checkpoint_path)
+
+ def _epoch_checkpoint_enabled(self, epoch: int) -> bool:
+ return epoch_checkpoint_enabled(
+ epoch, self.config.max_epochs, self.config.checkpoint_save_epochs
+ )
+
+ def _ema_epoch_checkpoint_enabled(self, epoch: int) -> bool:
+ return epoch_checkpoint_enabled(
+ epoch, self.config.max_epochs, self.config.ema_checkpoint_save_epochs
+ )
+
+ def save_all_checkpoints(self, valid_loss: float, inference_error: Optional[float]):
+ if self.config.validate_using_ema:
+ best_checkpoint_context = self._ema_context
+ else:
+ best_checkpoint_context = contextlib.nullcontext # type: ignore
+ with best_checkpoint_context():
+ save_best_checkpoint = False
+ if valid_loss <= self._best_validation_loss:
+ logging.info(
+ "Saving lowest validation loss checkpoint to "
+ f"{self.paths.best_checkpoint_path}"
+ )
+ self._best_validation_loss = valid_loss
+ save_best_checkpoint = True # wait until inference error is updated
+ if inference_error is not None and (
+ inference_error <= self._best_inference_error
+ ):
+ logging.info(
+ f"Epoch inference error ({inference_error}) is lower than "
+ f"previous best inference error ({self._best_inference_error})."
+ )
+ logging.info(
+ "Saving lowest inference error checkpoint to "
+ f"{self.paths.best_inference_checkpoint_path}"
+ )
+ self._best_inference_error = inference_error
+ self.save_checkpoint(self.paths.best_inference_checkpoint_path)
+ if save_best_checkpoint:
+ self.save_checkpoint(self.paths.best_checkpoint_path)
+
+ logging.info(f"Saving latest checkpoint to {self.paths.latest_checkpoint_path}")
+ self.save_checkpoint(self.paths.latest_checkpoint_path)
+ with self._ema_context():
+ logging.info(
+ f"Saving latest EMA checkpoint to {self.paths.ema_checkpoint_path}"
+ )
+ self.save_checkpoint(self.paths.ema_checkpoint_path)
+ if self._epoch_checkpoint_enabled(self._model_epoch):
+ epoch_checkpoint_path = self.paths.epoch_checkpoint_path(self._model_epoch)
+ logging.info(f"Saving epoch checkpoint to {epoch_checkpoint_path}")
+ self.save_checkpoint(epoch_checkpoint_path)
+ if self._ema_epoch_checkpoint_enabled(self._model_epoch):
+ ema_epoch_checkpoint_path = self.paths.ema_epoch_checkpoint_path(
+ self._model_epoch
+ )
+ logging.info(f"Saving EMA epoch checkpoint to {ema_epoch_checkpoint_path}")
+ with self._ema_context():
+ self.save_checkpoint(ema_epoch_checkpoint_path)
+
+
+def _restore_checkpoint(trainer: Trainer, checkpoint_path, ema_checkpoint_path):
+ # separated into a function only to make it easier to mock
+ checkpoint = torch.load(checkpoint_path, map_location=fme.get_device())
+ # restore checkpoint is used for finetuning as well as resuming.
+ # If finetuning (i.e., not resuming), restore checkpoint
+ # does not load optimizer state, instead uses config specified lr.
+ trainer.stepper.load_state(checkpoint["stepper"])
+ trainer.optimization.load_state(checkpoint["optimization"])
+ trainer.num_batches_seen = checkpoint["num_batches_seen"]
+ trainer._start_epoch = checkpoint["epoch"]
+ trainer._best_validation_loss = checkpoint["best_validation_loss"]
+ trainer._best_inference_error = checkpoint["best_inference_error"]
+ ema_checkpoint = torch.load(ema_checkpoint_path, map_location=fme.get_device())
+ ema_stepper: TrainStepperABC = type(trainer.stepper).from_state(
+ ema_checkpoint["stepper"]
+ )
+ trainer._ema = EMATracker.from_state(checkpoint["ema"], ema_stepper.modules)
+
+
+def count_parameters(modules: torch.nn.ModuleList) -> int:
+ parameters = 0
+ for module in modules:
+ for parameter in module.parameters():
+ if parameter.requires_grad:
+ parameters += parameter.numel()
+ return parameters
+
+
+def epoch_checkpoint_enabled(
+ epoch: int, max_epochs: int, save_epochs: Optional[Slice]
+) -> bool:
+ if save_epochs is None:
+ return False
+ return epoch in range(max_epochs)[save_epochs.slice]
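+
+
+# Illustrative behaviour of epoch_checkpoint_enabled (Slice is the config type
+# from fme.core.typing_ that wraps a builtin slice):
+#
+#   epoch_checkpoint_enabled(epoch=3, max_epochs=10, save_epochs=None)
+#   # False: per-epoch checkpoints are disabled when no Slice is configured
+#   epoch_checkpoint_enabled(epoch=9, max_epochs=10, save_epochs=Slice(start=5, stop=10))
+#   # True: 9 is in range(10)[5:10]
+#   epoch_checkpoint_enabled(epoch=3, max_epochs=10, save_epochs=Slice(start=0, stop=0))
+#   # False: an empty slice, used in tests to skip per-epoch checkpoints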
diff --git a/fme/fme/core/generics/writer.py b/fme/fme/core/generics/writer.py
new file mode 100644
index 0000000..1692503
--- /dev/null
+++ b/fme/fme/core/generics/writer.py
@@ -0,0 +1,46 @@
+import abc
+from typing import Any, Generic, TypeVar
+
+PS = TypeVar("PS", contravariant=True) # prognostic state
+SD = TypeVar("SD", contravariant=True) # stepped data
+
+
+class WriterABC(abc.ABC, Generic[PS, SD]):
+ @abc.abstractmethod
+ def write(self, data: PS, filename: str):
+ """Eagerly write data to a file at filename."""
+ ...
+
+ @abc.abstractmethod
+ def append_batch(
+ self,
+ batch: SD,
+ ):
+ """
+ Append a batch of data to the output file(s).
+
+ Args:
+ batch: Data to be written.
+ """
+ ...
+
+
+class NullDataWriter(WriterABC[Any, Any]):
+ """
+ Null pattern for DataWriter, which does nothing.
+ """
+
+ def __init__(self):
+ pass
+
+ def append_batch(
+ self,
+ batch: Any,
+ ):
+ pass
+
+ def flush(self):
+ pass
+
+ def write(self, data: Any, filename: str):
+ pass
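+
+
+# NullDataWriter is a stand-in for cases where no output should be written:
+# callers can hold a WriterABC and call write/append_batch/flush unconditionally
+# instead of branching on an Optional writer, e.g. `writer = NullDataWriter()`.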
diff --git a/fme/fme/core/gridded_ops.py b/fme/fme/core/gridded_ops.py
new file mode 100644
index 0000000..d2ef53f
--- /dev/null
+++ b/fme/fme/core/gridded_ops.py
@@ -0,0 +1,144 @@
+import abc
+from typing import Any, Dict, List, Type, TypeVar
+
+import torch
+
+from fme.core import metrics
+from fme.core.device import get_device
+
+
+class GriddedOperations(abc.ABC):
+ @abc.abstractmethod
+ def area_weighted_mean(
+ self, data: torch.Tensor, keepdim: bool = False
+ ) -> torch.Tensor: ...
+
+ def area_weighted_mean_bias(
+ self, truth: torch.Tensor, predicted: torch.Tensor
+ ) -> torch.Tensor:
+ return self.area_weighted_mean(predicted - truth)
+
+ def area_weighted_rmse(
+ self, truth: torch.Tensor, predicted: torch.Tensor
+ ) -> torch.Tensor:
+ return torch.sqrt(self.area_weighted_mean((predicted - truth) ** 2))
+
+ def area_weighted_std(self, data: torch.Tensor, keepdim: bool = False):
+ return self.area_weighted_mean(
+ (data - self.area_weighted_mean(data, keepdim=True)) ** 2,
+ keepdim=keepdim,
+ ).sqrt()
+
+ @abc.abstractmethod
+ def area_weighted_gradient_magnitude_percent_diff(
+ self, truth: torch.Tensor, predicted: torch.Tensor
+ ): ...
+
+ def to_state(self) -> Dict[str, Any]:
+ return {
+ "type": self.__class__.__name__,
+ "state": self.get_initialization_kwargs(),
+ }
+
+ @abc.abstractmethod
+ def get_initialization_kwargs(self) -> Dict[str, Any]:
+ """
+ Get the keyword arguments needed to initialize the instance.
+ """
+ ...
+
+ @classmethod
+ def from_state(cls, state: Dict[str, Any]) -> "GriddedOperations":
+ """
+ Given a dictionary with a "type" key and a "state" key, return
+ the GriddedOperations it describes.
+
+ The "type" key should be the name of a subclass of GriddedOperations,
+ and the "state" key should be a dictionary specific to
+ that subclass.
+
+ Args:
+ state: A dictionary with a "type" key and a "state" key.
+
+ Returns:
+ An instance of the subclass.
+ """
+ if cls is not GriddedOperations:
+ raise RuntimeError(
+ "This method should be called on GriddedOperations, "
+ "not on its subclasses."
+ )
+ subclasses = get_all_subclasses(cls)
+ for subclass in subclasses:
+ if subclass.__name__ == state["type"]:
+ return subclass(**state["state"])
+ raise ValueError(
+ f"Unknown subclass type: {state['type']}, "
+ f"available: {[s.__name__ for s in subclasses]}"
+ )
+
+
+T = TypeVar("T")
+
+
+def get_all_subclasses(cls: Type[T]) -> List[Type[T]]:
+ """
+ Gets all subclasses of a given class, including their subclasses etc.
+ """
+ all_subclasses = []
+ for subclass in cls.__subclasses__():
+ all_subclasses.append(subclass)
+ all_subclasses.extend(get_all_subclasses(subclass))
+ return all_subclasses
+
+
+class LatLonOperations(GriddedOperations):
+ HORIZONTAL_DIMS = (-2, -1)
+
+ def __init__(self, area_weights: torch.Tensor):
+ self._device_area = area_weights.to(get_device())
+ self._cpu_area = area_weights.to("cpu")
+
+ def area_weighted_mean(
+ self, data: torch.Tensor, keepdim: bool = False
+ ) -> torch.Tensor:
+ if data.device.type == "cpu":
+ area_weights = self._cpu_area
+ else:
+ area_weights = self._device_area
+ return metrics.weighted_mean(
+ data, area_weights, dim=self.HORIZONTAL_DIMS, keepdim=keepdim
+ )
+
+ def area_weighted_gradient_magnitude_percent_diff(
+ self, truth: torch.Tensor, predicted: torch.Tensor
+ ):
+ if predicted.device.type == "cpu":
+ area_weights = self._cpu_area
+ else:
+ area_weights = self._device_area
+ return metrics.gradient_magnitude_percent_diff(
+ truth, predicted, weights=area_weights, dim=self.HORIZONTAL_DIMS
+ )
+
+ def get_initialization_kwargs(self) -> Dict[str, Any]:
+ return {"area_weights": self._cpu_area}
+
+
+class HEALPixOperations(GriddedOperations):
+ HORIZONTAL_DIMS = (-3, -2, -1)
+
+ def area_weighted_mean(
+ self, data: torch.Tensor, keepdim: bool = False
+ ) -> torch.Tensor:
+ return data.mean(dim=self.HORIZONTAL_DIMS, keepdim=keepdim)
+
+ def area_weighted_gradient_magnitude_percent_diff(
+ self, truth: torch.Tensor, predicted: torch.Tensor
+ ):
+ return metrics.gradient_magnitude_percent_diff(
+ truth, predicted, weights=None, dim=self.HORIZONTAL_DIMS
+ )
+
+ def get_initialization_kwargs(self) -> Dict[str, Any]:
+ return {}
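+
+
+# Minimal usage sketch (shapes and values are illustrative):
+#
+#   weights = torch.ones(8, 16)                  # (lat, lon) cell areas
+#   ops = LatLonOperations(area_weights=weights)
+#   data = torch.randn(2, 3, 8, 16)              # (..., lat, lon)
+#   mean = ops.area_weighted_mean(data)          # horizontal dims removed -> (2, 3)
+#
+#   # from_state rebuilds the correct subclass by its class name:
+#   restored = GriddedOperations.from_state(ops.to_state())
+#   assert isinstance(restored, LatLonOperations)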
diff --git a/fme/fme/core/histogram.py b/fme/fme/core/histogram.py
index 8cf4b45..91d9c97 100644
--- a/fme/fme/core/histogram.py
+++ b/fme/fme/core/histogram.py
@@ -1,7 +1,7 @@
import collections
import logging
from collections import namedtuple
-from typing import Dict, List, Literal, Mapping, Optional, Tuple
+from typing import Dict, List, Literal, Mapping, Optional, Tuple, Union
import matplotlib.figure
import matplotlib.pyplot as plt
@@ -128,6 +128,8 @@ def _double_size_left(self):
Double the sizes of bins, extending the histogram
to the left (further negative).
"""
+ if self.bin_edges is None:
+ raise RuntimeError("Cannot double size of bins without bin edges")
current_range = self.bin_edges[-1] - self.bin_edges[0]
new_range = 2 * current_range
@@ -147,6 +149,8 @@ def _double_size_right(self):
Double the sizes of bins, extending the histogram
to the right (further positive).
"""
+ if self.bin_edges is None:
+ raise RuntimeError("Cannot double size of bins without bin edges")
current_range = self.bin_edges[-1] - self.bin_edges[0]
new_range = 2 * current_range
@@ -167,7 +171,8 @@ def _double_size_right(self):
class ComparedDynamicHistograms:
"""Wrapper of DynamicHistogram for multiple histograms, two histograms per
- variable plotted on the same axis."""
+ variable plotted on the same axis.
+ """
def __init__(self, n_bins: int, percentiles: Optional[List[float]] = None) -> None:
self.n_bins = n_bins
@@ -205,9 +210,9 @@ def _get_histograms(
) -> Dict[str, Dict[Literal["target", "prediction"], _Histogram]]:
if self.target_histograms is None or self.prediction_histograms is None:
raise ValueError("No data has been added to the histogram")
- return_dict: Dict[
- str, Dict[Literal["target", "prediction"], _Histogram]
- ] = collections.defaultdict(dict)
+ return_dict: Dict[str, Dict[Literal["target", "prediction"], _Histogram]] = (
+ collections.defaultdict(dict)
+ )
for k in self.target_histograms:
counts, bin_edges = trim_zero_bins(
self.target_histograms[k].counts.squeeze(self._time_dim),
@@ -250,7 +255,7 @@ def _plot_histogram(
return fig
def get_wandb(self) -> Dict[str, float]:
- return_dict: Dict[str, float] = {}
+ return_dict: Dict[str, Union[matplotlib.figure.Figure, float]] = {}
for field_name, histograms in self._get_histograms().items():
target = histograms.get("target")
@@ -309,9 +314,9 @@ def get_dataset(self) -> xr.Dataset:
np.zeros_like(target_dataset[missing_prediction_name]),
dims=("bin",),
)
- prediction_dataset[
- f"{missing_prediction_name}_bin_edges"
- ] = target_dataset[f"{missing_prediction_name}_bin_edges"]
+ prediction_dataset[f"{missing_prediction_name}_bin_edges"] = (
+ target_dataset[f"{missing_prediction_name}_bin_edges"]
+ )
ds = xr.concat([target_dataset, prediction_dataset], dim="source")
ds["source"] = ["target", "prediction"]
return ds
diff --git a/fme/fme/core/logging_utils.py b/fme/fme/core/logging_utils.py
index 56bda8d..c33e1e2 100644
--- a/fme/fme/core/logging_utils.py
+++ b/fme/fme/core/logging_utils.py
@@ -25,7 +25,7 @@ class LoggingConfig:
"""
Configuration for logging.
- Attributes:
+ Parameters:
project: name of the project in Weights & Biases
entity: name of the entity in Weights & Biases
log_to_screen: whether to log to the screen
diff --git a/fme/fme/core/loss.py b/fme/fme/core/loss.py
index 3af0fe3..8497f0f 100644
--- a/fme/fme/core/loss.py
+++ b/fme/fme/core/loss.py
@@ -1,15 +1,12 @@
import dataclasses
-from typing import Any, Dict, List, Literal, Mapping, Optional
+from typing import Any, Callable, Dict, List, Literal, Mapping, Optional
import torch
import torch.linalg
-from fme.core.data_loading.data_typing import SigmaCoordinates
from fme.core.device import get_device
from fme.core.packer import Packer
-from fme.core.typing_ import TensorDict, TensorMapping
-
-from .climate_data import ClimateData, compute_dry_air_absolute_differences
+from fme.core.typing_ import TensorDict
class NaNLoss(torch.nn.Module):
@@ -37,27 +34,6 @@ def __call__(
return self.loss(predict_tensors, target_tensors)
-def get_dry_air_nonconservation(
- data: TensorMapping,
- area_weights: torch.Tensor,
- sigma_coordinates: SigmaCoordinates,
-):
- """
- Computes the time-average one-step absolute difference in surface pressure due to
- changes in globally integrated dry air.
-
- Args:
- data: A mapping from variable name to tensor of shape
- [sample, time, lat, lon], in physical units. specific_total_water in kg/kg
- and surface_pressure in Pa must be present.
- area_weights: The area of each grid cell as a [lat, lon] tensor, in m^2.
- sigma_coordinates: The sigma coordinates of the model.
- """
- return compute_dry_air_absolute_differences(
- ClimateData(data), area=area_weights, sigma_coordinates=sigma_coordinates
- ).mean()
-
-
def _construct_weight_tensor(
weights: Dict[str, float],
out_names: List[str],
@@ -76,7 +52,6 @@ def _construct_weight_tensor(
n_dim: number of dimensions of the output tensor
channel_dim: the channel dimension of the output tensor
"""
-
missing_keys = set(weights.keys()) - set(out_names)
if len(missing_keys) > 0:
raise KeyError(
@@ -120,12 +95,12 @@ def __call__(self, x, y):
class AreaWeightedMSELoss(torch.nn.Module):
- def __init__(self, area: torch.Tensor):
+ def __init__(self, area_weighted_mean: Callable[[torch.Tensor], torch.Tensor]):
super(AreaWeightedMSELoss, self).__init__()
- self._area_weights = area / area.mean()
+ self._area_weighted_mean = area_weighted_mean
def __call__(self, x, y):
- return torch.mean((x - y) ** 2 * self._area_weights)
+ return torch.mean(self._area_weighted_mean((x - y) ** 2))
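+
+# The area_weighted_mean callable can be supplied by a GriddedOperations
+# instance (see fme.core.gridded_ops); a hypothetical usage sketch:
+#
+#   loss = AreaWeightedMSELoss(ops.area_weighted_mean)
+#   value = loss(predicted, target)  # mean of the area-weighted squared error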
class WeightedSum(torch.nn.Module):
@@ -157,17 +132,21 @@ class GlobalMeanLoss(torch.nn.Module):
A module which computes a loss on the global mean of each sample.
"""
- def __init__(self, area: torch.Tensor, loss: torch.nn.Module):
+ def __init__(
+ self,
+ area_weighted_mean: Callable[[torch.Tensor], torch.Tensor],
+ loss: torch.nn.Module,
+ ):
"""
Args:
- area: A tensor of shape (n_lat, n_lon) containing the area of
- each grid cell.
+ area_weighted_mean: Computes an area-weighted mean, removing the
+ horizontal dimensions.
loss: A loss function which takes two tensors of shape
(n_samples, n_timesteps, n_channels) and returns a scalar
tensor.
"""
super().__init__()
- self.global_mean = GlobalMean(area)
+ self.global_mean = GlobalMean(area_weighted_mean)
self.loss = loss
def __call__(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
@@ -177,21 +156,22 @@ def __call__(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
class GlobalMean(torch.nn.Module):
- def __init__(self, area: torch.Tensor):
+ def __init__(self, area_weighted_mean: Callable[[torch.Tensor], torch.Tensor]):
"""
Args:
- area: A tensor of shape (n_lat, n_lon) containing the area of
- each grid cell.
+ area_weighted_mean: Computes an area-weighted mean, removing the
+ horizontal dimensions.
"""
super().__init__()
- self.area_weights = area / area.sum()
+ self._area_weighted_mean = area_weighted_mean
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
- x: A tensor of shape (n_samples, n_timesteps, n_channels, n_lat, n_lon)
+ x: A tensor with spatial dimensions in shape (n_samples, n_timesteps,
+ n_channels, n_lat, n_lon).
"""
- return (x * self.area_weights[None, None, None, :, :]).sum(dim=(3, 4))
+ return self._area_weighted_mean(x)
class VariableWeightingLoss(torch.nn.Module):
@@ -200,6 +180,7 @@ def __init__(self, weights: torch.Tensor, loss: torch.nn.Module):
Args:
weights: A tensor of shape (n_samples, n_channels, n_lat, n_lon)
containing the weights to apply to each channel.
+ loss: A loss function which takes two tensors.
"""
super().__init__()
self.loss = loss
@@ -240,11 +221,18 @@ def __post_init__(self):
if self.global_mean_type is not None and self.global_mean_type != "LpLoss":
raise NotImplementedError(self.global_mean_type)
- def build(self, area: torch.Tensor, reduction: Literal["mean", "none"]) -> Any:
+ def build(
+ self,
+ area_weighted_mean: Callable[[torch.Tensor], torch.Tensor],
+ reduction: Literal["mean", "none"],
+ ) -> Any:
"""
Args:
- area: A tensor of shape (n_lat, n_lon) containing the area of
- each grid cell.
+ area_weighted_mean: Computes an area-weighted mean, removing the
+ horizontal dimensions. Only used if the loss function is
+ AreaWeightedMSE.
+ reduction: The reduction to apply to the loss, either "mean" or "none".
+ Only used if the loss function is L1, MSE, or LpLoss.
"""
if self.type == "LpLoss":
main_loss = LpLoss(**self.kwargs)
@@ -253,13 +241,14 @@ def build(self, area: torch.Tensor, reduction: Literal["mean", "none"]) -> Any:
elif self.type == "MSE":
main_loss = torch.nn.MSELoss(reduction=reduction)
elif self.type == "AreaWeightedMSE":
- main_loss = AreaWeightedMSELoss(area)
+ main_loss = AreaWeightedMSELoss(area_weighted_mean)
elif self.type == "NaN":
main_loss = NaNLoss()
if self.global_mean_type is not None:
global_mean_loss = GlobalMeanLoss(
- area=area, loss=LpLoss(**self.global_mean_kwargs)
+ area_weighted_mean=area_weighted_mean,
+ loss=LpLoss(**self.global_mean_kwargs),
)
final_loss = WeightedSum(
modules=[main_loss, global_mean_loss],
@@ -312,9 +301,12 @@ def __post_init__(self):
)
def build(
- self, area: torch.Tensor, out_names: List[str], channel_dim: int = -3
+ self,
+ area_weighted_mean: Callable[[torch.Tensor], torch.Tensor],
+ out_names: List[str],
+ channel_dim: int = -3,
) -> Any:
- loss = self.loss_config.build(area, reduction="mean")
+ loss = self.loss_config.build(area_weighted_mean, reduction="mean")
weighted_loss = VariableWeightingLoss(
weights=_construct_weight_tensor(
self.weights, out_names, channel_dim=channel_dim
diff --git a/fme/fme/core/masking.py b/fme/fme/core/masking.py
new file mode 100644
index 0000000..d173512
--- /dev/null
+++ b/fme/fme/core/masking.py
@@ -0,0 +1,129 @@
+import dataclasses
+from typing import Optional
+
+import torch
+
+from fme.core.stacker import Stacker, unstack
+from fme.core.typing_ import TensorDict, TensorMapping
+
+
+def replace_on_mask(
+ original: torch.Tensor,
+ replacement: torch.Tensor,
+ mask: torch.Tensor,
+ mask_value: int,
+):
+ """Replace original with replacement in masked regions.
+
+ Args:
+ original: The original data tensor.
+ replacement: The replacement data tensor.
+ mask: The mask tensor.
+ mask_value: The value of the mask variable in the region to be replaced.
+ """
+ rounded_mask = torch.round(mask).to(int)
+ return torch.where(
+ condition=rounded_mask == mask_value,
+ input=replacement,
+ other=original,
+ )
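+
+
+# Example with illustrative values: where the rounded mask equals mask_value the
+# replacement is used, otherwise the original is kept.
+#
+#   original = torch.tensor([1.0, 2.0, 3.0])
+#   replacement = torch.zeros(3)
+#   mask = torch.tensor([0.0, 0.9, 0.2])
+#   replace_on_mask(original, replacement, mask, mask_value=1)
+#   # -> tensor([1., 0., 3.])  (0.9 rounds to 1)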
+
+
+@dataclasses.dataclass
+class MaskingConfig:
+ """
+ Configuration for applying masking to the generated output.
+
+ Parameters:
+ mask_name: The standard name of the mask. May be a prefix (for a 3D masking
+ variable) or a full name (for a 2D masking variable).
+ mask_value: Value of the mask variable in masked regions. Either 0 or 1.
+ fill_value: The constant fill value to use outside of masked regions.
+ surface_mask_name: (optional) The full name of the surface mask. Only required
+ when mask_name is a prefix and separate 2D surface masking is desired.
+ """
+
+ mask_name: str
+ mask_value: int
+ fill_value: float = 0.0
+ surface_mask_name: Optional[str] = None
+
+ def __post_init__(self):
+ if self.mask_value not in [0, 1]:
+ raise ValueError(
+ "mask_value must be either 0 or 1, but got " f"{self.mask_value}"
+ )
+
+ def build(self):
+ return Masking(
+ mask_name=self.mask_name,
+ mask_value=self.mask_value,
+ fill_value=self.fill_value,
+ surface_mask_name=self.surface_mask_name,
+ )
+
+
+class Masking:
+ """Replace masked regions with a fill value."""
+
+ def __init__(
+ self,
+ mask_name: str,
+ mask_value: int,
+ fill_value: float,
+ surface_mask_name: Optional[str] = None,
+ ):
+ self.mask_name = mask_name
+ self.mask_value = mask_value
+ self.fill_value = fill_value
+ self.surface_mask_name = surface_mask_name
+ mask_map = {self.mask_name: [self.mask_name]}
+ if self.surface_mask_name is not None:
+ mask_map[self.surface_mask_name] = [self.surface_mask_name]
+
+ self.mask_stacker = Stacker(mask_map)
+
+ def __call__(
+ self,
+ stacker: Stacker,
+ data: TensorMapping,
+ mask_data: TensorMapping,
+ ) -> TensorDict:
+ """
+ Apply masking to the data for standard names recognized by a stacker.
+
+ Args:
+ stacker: A Stacker for variables to mask in data.
+ data: The data to mask.
+ mask_data: The mask data.
+
+ """
+ mask = self.mask_stacker(self.mask_name, mask_data)
+ if self.surface_mask_name is not None:
+ surface_mask = self.mask_stacker(self.surface_mask_name, mask_data)
+ else:
+ surface_mask = None
+ data_: TensorDict = {**data}
+ for name in stacker.standard_names:
+ stacked = stacker(name, data)
+ if stacked.size(-1) > 1: # 3D masking
+ mask_ = mask
+ elif surface_mask is not None: # 2D masking with surface mask
+ mask_ = surface_mask
+ elif mask.size(-1) == 1: # 2D masking with mask
+ mask_ = mask
+ else:
+ raise RuntimeError(
+ "Masking surface_mask_name is None but the input Stacker "
+ f"includes the 2D standard name {name}."
+ )
+ fill = torch.full_like(stacked, self.fill_value)
+ masked = replace_on_mask(
+ original=stacked,
+ replacement=fill,
+ mask=mask_,
+ mask_value=self.mask_value,
+ )
+ level_names = stacker.get_all_level_names(name, data)
+ data_.update(unstack(masked, level_names, dim=-1))
+ return data_
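+
+
+# Hypothetical configuration sketch: "mask" as the prefix of a 3D mask stacked
+# over vertical levels, plus a separate 2D surface mask; field values are
+# illustrative only.
+#
+#   masking = MaskingConfig(
+#       mask_name="mask",
+#       mask_value=0,
+#       fill_value=0.0,
+#       surface_mask_name="surface_mask",
+#   ).build()
+#   masked = masking(stacker, data, mask_data)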
diff --git a/fme/fme/core/metrics.py b/fme/fme/core/metrics.py
index 7194ecf..b864c4b 100644
--- a/fme/fme/core/metrics.py
+++ b/fme/fme/core/metrics.py
@@ -6,6 +6,7 @@
from typing_extensions import TypeAlias
from fme.core.constants import GRAVITY
+from fme.core.coordinates import HybridSigmaPressureCoordinate
Dimension: TypeAlias = Union[int, Iterable[int]]
Array: TypeAlias = Union[np.ndarray, torch.Tensor]
@@ -114,12 +115,11 @@ def root_mean_squared_error(
dim: Dimension = (),
) -> torch.Tensor:
"""
- Computes the weighted global RMSE over all variables. Namely, for each variable:
+ Compute a weighted root mean square error between truth and predicted.
- sqrt((weights * ((xhat - x) ** 2)).mean(dims))
+ Namely:
- If you want to compute the RMSE over the time dimension, then pass in
- `truth.mean(time_dim)` and `predicted.mean(time_dim)` and specify `dims=space_dims`.
+ sqrt((weights * ((xhat - x) ** 2)).mean(dims))
Args:
truth: torch.Tensor whose last dimensions are to be weighted
@@ -128,7 +128,7 @@ def root_mean_squared_error(
dim: Dimensions to average over.
Returns:
- a tensor of shape (variable,) of weighted RMSEs.
+ A tensor of weighted RMSEs.
"""
assert (
truth.shape == predicted.shape
@@ -157,7 +157,8 @@ def gradient_magnitude_percent_diff(
dim: Dimension = (),
) -> torch.Tensor:
"""Compute the percent difference of the weighted mean gradient magnitude across
- the specified dimensions."""
+ the specified dimensions.
+ """
truth_grad_mag = weighted_mean_gradient_magnitude(truth, weights, dim)
predicted_grad_mag = weighted_mean_gradient_magnitude(predicted, weights, dim)
return 100 * (predicted_grad_mag - truth_grad_mag) / truth_grad_mag
@@ -219,74 +220,23 @@ def time_and_global_mean_bias(
return result
-def vertical_integral(
- integrand: torch.Tensor,
- surface_pressure: torch.Tensor,
- sigma_grid_offsets_ak: torch.Tensor,
- sigma_grid_offsets_bk: torch.Tensor,
-) -> torch.Tensor:
- """Computes a vertical integral, namely:
-
- (1 / g) * ∫ x dp
-
- where
- - g = acceleration due to gravity
- - x = integrad
- - p = pressure level
-
- Args:
- integrand (lat, lon, vertical_level), (kg/kg)
- surface_pressure: (lat, lon), (Pa)
- sigma_grid_offsets_ak: Sorted sigma grid offsets ak, (vertical_level + 1,)
- sigma_grid_offsets_bk: Sorted sigma grid offsets bk, (vertical_level + 1,)
-
- Returns:
- Vertical integral of the integrand (lat, lon).
- """
- ak, bk = sigma_grid_offsets_ak, sigma_grid_offsets_bk
- pressure_thickness = ((ak + (surface_pressure.unsqueeze(-1) * bk))).diff(
- dim=-1
- ) # Pa
- integral = torch.sum(pressure_thickness * integrand, axis=-1) # type: ignore
- return 1 / GRAVITY * integral
-
-
def surface_pressure_due_to_dry_air(
specific_total_water: torch.Tensor,
surface_pressure: torch.Tensor,
- sigma_grid_offsets_ak: torch.Tensor,
- sigma_grid_offsets_bk: torch.Tensor,
+ vertical_coordinate: HybridSigmaPressureCoordinate,
) -> torch.Tensor:
"""Computes the dry air (Pa).
Args:
- specific_total_water (lat, lon, vertical_level), (kg/kg)
- surface_pressure: (lat, lon), (Pa)
- sigma_grid_offsets_ak: Sorted sigma grid offsets ak, (vertical_level + 1,)
- sigma_grid_offsets_bk: Sorted sigma grid offsets bk, (vertical_level + 1,)
+ specific_total_water: specific total water (kg/kg), with the vertical
+ level as the last dimension.
+ surface_pressure: the surface pressure in Pa.
+ vertical_coordinate: the vertical coordinate used to compute the
+ vertical integral.
Returns:
- Vertically integrated dry air (lat, lon) (Pa)
+ The surface pressure due to dry air mass only. (Pa)
"""
-
- num_levels = len(sigma_grid_offsets_ak) - 1
-
- if (
- num_levels != len(sigma_grid_offsets_bk) - 1
- or num_levels != specific_total_water.shape[-1]
- ):
- raise ValueError(
- (
- "Number of vertical levels in ak, bk, and specific_total_water must"
- "be the same."
- )
- )
-
- total_water_path = vertical_integral(
- specific_total_water,
- surface_pressure,
- sigma_grid_offsets_ak,
- sigma_grid_offsets_bk,
+ total_water_path = vertical_coordinate.vertical_integral(
+ specific_total_water, surface_pressure
)
dry_air = surface_pressure - GRAVITY * total_water_path
return dry_air
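+
+
+# Equivalently: with total water path TWP = (1 / g) * sum_k q_k * dp_k, as
+# returned by vertical_coordinate.vertical_integral, the dry-air surface
+# pressure is ps_dry = ps - g * TWP.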
diff --git a/fme/fme/core/normalizer.py b/fme/fme/core/normalizer.py
index 6bc3855..49cebc8 100644
--- a/fme/fme/core/normalizer.py
+++ b/fme/fme/core/normalizer.py
@@ -6,8 +6,8 @@
import torch
import torch.jit
-from fme.core.device import get_device
-from fme.core.typing_ import TensorDict
+from fme.core.device import move_tensordict_to_device
+from fme.core.typing_ import TensorDict, TensorMapping
@dataclasses.dataclass
@@ -18,17 +18,15 @@ class NormalizationConfig:
Either global_means_path and global_stds_path or explicit means and stds
must be provided.
- Attributes:
+ Parameters:
global_means_path: Path to a netCDF file containing global means.
global_stds_path: Path to a netCDF file containing global stds.
- exclude_names: Names to exclude from normalization.
means: Mapping from variable names to means.
stds: Mapping from variable names to stds.
"""
global_means_path: Optional[str] = None
global_stds_path: Optional[str] = None
- exclude_names: Optional[List[str]] = None
means: Mapping[str, float] = dataclasses.field(default_factory=dict)
stds: Mapping[str, float] = dataclasses.field(default_factory=dict)
@@ -49,8 +47,6 @@ def __post_init__(self):
)
def build(self, names: List[str]):
- if self.exclude_names is not None:
- names = list(set(names) - set(self.exclude_names))
using_path = (
self.global_means_path is not None and self.global_stds_path is not None
)
@@ -66,24 +62,6 @@ def build(self, names: List[str]):
return StandardNormalizer(means=means, stds=stds)
-@dataclasses.dataclass
-class FromStateNormalizer:
- """
- An alternative to NormalizationConfig which provides a normalizer
- initialized from a serializable state. This is not a public configuration
- class, but instead allows for loading trained models that have been
- serialized to disk, using the pre-existing normalization state.
-
- Attributes:
- state: State dict of a normalizer.
- """
-
- state: Dict[str, Dict[str, float]]
-
- def build(self, names: List[str]):
- return StandardNormalizer.from_state(self.state)
-
-
class StandardNormalizer:
"""
Responsible for normalizing tensors.
@@ -94,14 +72,17 @@ def __init__(
means: TensorDict,
stds: TensorDict,
):
- self.means = means
- self.stds = stds
+ self.means = move_tensordict_to_device(means)
+ self.stds = move_tensordict_to_device(stds)
+ self._names = set(means).intersection(stds)
- def normalize(self, tensors: TensorDict) -> TensorDict:
- return _normalize(tensors, means=self.means, stds=self.stds)
+ def normalize(self, tensors: TensorMapping) -> TensorDict:
+ filtered_tensors = {k: v for k, v in tensors.items() if k in self._names}
+ return _normalize(filtered_tensors, means=self.means, stds=self.stds)
- def denormalize(self, tensors: TensorDict) -> TensorDict:
- return _denormalize(tensors, means=self.means, stds=self.stds)
+ def denormalize(self, tensors: TensorMapping) -> TensorDict:
+ filtered_tensors = {k: v for k, v in tensors.items() if k in self._names}
+ return _denormalize(filtered_tensors, means=self.means, stds=self.stds)
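+
+ # Variables without normalization statistics are dropped rather than passed
+ # through unchanged, e.g. (illustrative values):
+ #
+ #   normalizer = StandardNormalizer(
+ #       means={"t2m": torch.tensor(288.0)}, stds={"t2m": torch.tensor(10.0)}
+ #   )
+ #   normalizer.normalize({"t2m": x, "extra": y})  # -> {"t2m": (x - 288.0) / 10.0}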
def get_state(self):
"""
@@ -113,19 +94,15 @@ def get_state(self):
}
@classmethod
- def from_state(self, state) -> "StandardNormalizer":
+ def from_state(cls, state) -> "StandardNormalizer":
"""
Loads state from a serializable data structure.
"""
means = {
- k: torch.tensor(v, device=get_device(), dtype=torch.float)
- for k, v in state["means"].items()
- }
- stds = {
- k: torch.tensor(v, device=get_device(), dtype=torch.float)
- for k, v in state["stds"].items()
+ k: torch.tensor(v, dtype=torch.float) for k, v in state["means"].items()
}
- return StandardNormalizer(means=means, stds=stds)
+ stds = {k: torch.tensor(v, dtype=torch.float) for k, v in state["stds"].items()}
+ return cls(means=means, stds=stds)
@torch.jit.script
@@ -134,10 +111,7 @@ def _normalize(
means: TensorDict,
stds: TensorDict,
) -> TensorDict:
- return {
- k: (t - means[k]) / stds[k] if k in means.keys() else t
- for k, t in tensors.items()
- }
+ return {k: (t - means[k]) / stds[k] for k, t in tensors.items()}
@torch.jit.script
@@ -146,10 +120,7 @@ def _denormalize(
means: TensorDict,
stds: TensorDict,
) -> TensorDict:
- return {
- k: t * stds[k] + means[k] if k in means.keys() else t
- for k, t in tensors.items()
- }
+ return {k: t * stds[k] + means[k] for k, t in tensors.items()}
def get_normalizer(
diff --git a/fme/fme/core/ocean.py b/fme/fme/core/ocean.py
index 21cdf39..67b697c 100644
--- a/fme/fme/core/ocean.py
+++ b/fme/fme/core/ocean.py
@@ -16,7 +16,7 @@ class SlabOceanConfig:
"""
Configuration for a slab ocean model.
- Attributes:
+ Parameters:
mixed_layer_depth_name: Name of the mixed layer depth field.
q_flux_name: Name of the heat flux field.
"""
@@ -34,7 +34,7 @@ class OceanConfig:
"""
Configuration for determining sea surface temperature from an ocean model.
- Attributes:
+ Parameters:
surface_temperature_name: Name of the sea surface temperature field.
ocean_fraction_name: Name of the ocean fraction field.
interpolate: If True, interpolate between ML-predicted surface temperature and
@@ -102,16 +102,16 @@ def __init__(self, config: OceanConfig, timestep: datetime.timedelta):
def __call__(
self,
- target_data: TensorMapping,
input_data: TensorMapping,
gen_data: TensorMapping,
+ target_data: TensorMapping,
) -> TensorDict:
"""
Args:
- target_data: Denormalized data that includes mask and forcing data. Assumed
- to correspond to the same time step as gen_data.
input_data: Denormalized input data for current step.
gen_data: Denormalized output data for current step.
+ target_data: Denormalized data that includes mask and forcing data. Assumed
+ to correspond to the same time step as gen_data.
Returns:
gen_data with sea surface temperature overwritten by ocean model.
diff --git a/fme/fme/core/optimization.py b/fme/fme/core/optimization.py
index 1638be5..2fdb1ca 100644
--- a/fme/fme/core/optimization.py
+++ b/fme/fme/core/optimization.py
@@ -1,18 +1,20 @@
import contextlib
import dataclasses
-from typing import Any, Literal, Mapping, Optional
+import itertools
+from typing import Any, Iterable, Literal, Mapping, Optional
import torch
import torch.cuda.amp as amp
from torch import nn
+from fme.core.generics.optimization import OptimizationABC
from fme.core.scheduler import SchedulerConfig
-class Optimization:
+class Optimization(OptimizationABC):
def __init__(
self,
- parameters,
+ parameters: Iterable[torch.nn.Parameter],
optimizer_type: Literal["Adam", "FusedAdam"],
lr: float,
max_epochs: int,
@@ -114,7 +116,7 @@ class OptimizationConfig:
"""
Configuration for optimization.
- Attributes:
+ Parameters:
optimizer_type: The type of optimizer to use.
lr: The learning rate.
kwargs: Additional keyword arguments to pass to the optimizer.
@@ -132,7 +134,8 @@ class OptimizationConfig:
default_factory=lambda: SchedulerConfig()
)
- def build(self, parameters, max_epochs: int) -> Optimization:
+ def build(self, modules: torch.nn.ModuleList, max_epochs: int) -> Optimization:
+ parameters = itertools.chain(*[module.parameters() for module in modules])
return Optimization(
parameters=parameters,
optimizer_type=self.optimizer_type,
@@ -151,7 +154,7 @@ def from_state(cls, state: Mapping[str, Any]) -> "OptimizationConfig":
return cls(**state)
-class NullOptimization:
+class NullOptimization(OptimizationABC):
@contextlib.contextmanager
def autocast(self):
yield
diff --git a/fme/fme/core/packer.py b/fme/fme/core/packer.py
index 362591e..9d10dca 100644
--- a/fme/fme/core/packer.py
+++ b/fme/fme/core/packer.py
@@ -22,7 +22,7 @@ def __init__(self, names: List[str]):
def pack(self, tensors: TensorDict, axis=0) -> torch.Tensor:
"""
- Packs tensors into a single tensor, concatenated along a new axis
+ Packs tensors into a single tensor, concatenated along a new axis.
Args:
tensors: Dict from names to tensors.
diff --git a/fme/fme/core/parameter_init.py b/fme/fme/core/parameter_init.py
index d49649d..d3fdcd7 100644
--- a/fme/fme/core/parameter_init.py
+++ b/fme/fme/core/parameter_init.py
@@ -23,7 +23,7 @@ class FrozenParameterConfig:
An exception is raised if a parameter is included by both lists.
- Attributes:
+ Parameters:
include: list of parameter names to freeze (set requires_grad = False)
exclude: list of parameter names to ignore
"""
@@ -70,7 +70,7 @@ class ParameterInitializationConfig:
pre-trained model. If the built model has larger weights than the
pre-trained model, only the initial slice of the weights is overwritten.
- Attributes:
+ Parameters:
weight_path: path to a SingleModuleStepper checkpoint
containing weights to load
exclude_parameters: list of parameter names to exclude from the loaded
@@ -123,6 +123,8 @@ def apply(
def regularizer():
return torch.tensor(0.0, device=device)
+ return module, regularizer
+
else:
loaded_state_dict = {
name: value.to(device) for name, value in loaded_state_dict.items()
@@ -136,6 +138,8 @@ def regularizer():
"which is not allowed"
)
+ non_optional_state_dict = loaded_state_dict
+
def regularizer():
loss = torch.tensor(0.0, device=device)
for name in from_names:
@@ -155,7 +159,7 @@ def regularizer():
self.alpha
/ 2
* torch.linalg.norm(
- (param - loaded_state_dict[name]).flatten(),
+ (param - non_optional_state_dict[name]).flatten(),
ord=2,
)
)
diff --git a/fme/fme/core/prescriber.py b/fme/fme/core/prescriber.py
index 3a813a2..d0b1f64 100644
--- a/fme/fme/core/prescriber.py
+++ b/fme/fme/core/prescriber.py
@@ -1,8 +1,7 @@
import dataclasses
from typing import List
-import torch
-
+from fme.core.masking import replace_on_mask
from fme.core.typing_ import TensorDict, TensorMapping
@@ -17,7 +16,7 @@ class PrescriberConfig:
target value at 1 based on the mask variable, and it is assumed the mask variable
lies in the range from 0 to 1.
- Attributes:
+ Parameters:
prescribed_name: Name of the variable to be overwritten.
mask_name: Name of the mask variable.
mask_value: Value of the mask variable in the region to be overwritten.
@@ -78,7 +77,7 @@ def __call__(
) -> TensorDict:
"""
Args:
- data: Dictionary of data containing the mask variable.
+ mask_data: Dictionary of data containing the mask variable.
gen: Dictionary of data to use outside of mask region.
target: Dictionary of data to use in mask region.
@@ -104,11 +103,11 @@ def __call__(
)
else:
# overwrite specified target variable in given mask region
- rounded_mask = torch.round(mask_data[self.mask_name]).to(int)
- output = torch.where(
- condition=rounded_mask == self.mask_value,
- input=target[self.prescribed_name],
- other=gen[self.prescribed_name],
+ output = replace_on_mask(
+ original=gen[self.prescribed_name],
+ replacement=target[self.prescribed_name],
+ mask=mask_data[self.mask_name],
+ mask_value=self.mask_value,
)
return {**gen, self.prescribed_name: output}
diff --git a/fme/fme/core/registry/__init__.py b/fme/fme/core/registry/__init__.py
new file mode 100644
index 0000000..867272f
--- /dev/null
+++ b/fme/fme/core/registry/__init__.py
@@ -0,0 +1,2 @@
+from .corrector import CorrectorSelector
+from .module import ModuleSelector
diff --git a/fme/fme/core/registry/corrector.py b/fme/fme/core/registry/corrector.py
new file mode 100644
index 0000000..a2ac927
--- /dev/null
+++ b/fme/fme/core/registry/corrector.py
@@ -0,0 +1,90 @@
+import dataclasses
+import datetime
+from typing import Any, Callable, ClassVar, Mapping, Type, TypeVar
+
+import dacite
+
+from fme.core.coordinates import HybridSigmaPressureCoordinate
+from fme.core.corrector.registry import CorrectorConfigProtocol
+from fme.core.gridded_ops import GriddedOperations
+
+from .registry import Registry
+
+CT = TypeVar("CT", bound=Type[CorrectorConfigProtocol])
+
+
+@dataclasses.dataclass
+class CorrectorSelector:
+ """
+ A dataclass containing all the information needed to build a
+ CorrectorConfigProtocol, including the type of the CorrectorConfigProtocol
+ and the data needed to build it.
+
+ This is helpful as CorrectorSelector can be serialized and deserialized
+ without any additional information, whereas to load a
+ CorrectorConfigProtocol you would need to know the type of the
+ CorrectorConfigProtocol being loaded.
+
+ It is also convenient because CorrectorSelector is a single class that can
+ be used to represent any CorrectorConfigProtocol, whereas
+ CorrectorConfigProtocol is a protocol that can be implemented by many
+ different classes.
+
+ Parameters:
+ type: the type of the CorrectorConfigProtocol
+ config: data for a CorrectorConfigProtocol instance of the indicated type
+
+ """
+
+ type: str
+ config: Mapping[str, Any]
+ registry: ClassVar[Registry] = Registry()
+
+ def __post__init(self):
+ if self.registry is not Registry():
+ raise ValueError("CorrectorSelector.registry should not be set manually")
+
+ @classmethod
+ def register(cls, type_name) -> Callable[[CT], CT]:
+ return cls.registry.register(type_name)
+
+ def build(
+ self,
+ gridded_operations: GriddedOperations,
+ vertical_coordinate: HybridSigmaPressureCoordinate,
+ timestep: datetime.timedelta,
+ ):
+ instance = self.registry.from_dict(self.get_state())
+ return instance.build(
+ gridded_operations=gridded_operations,
+ vertical_coordinate=vertical_coordinate,
+ timestep=timestep,
+ )
+
+ def get_state(self) -> Mapping[str, Any]:
+ """
+ Get a dictionary containing all the information needed to build a
+ CorrectorConfigProtocol.
+
+ """
+ return {"type": self.type, "config": self.config}
+
+ @classmethod
+ def from_state(cls, state: Mapping[str, Any]) -> "CorrectorSelector":
+ """
+ Create a CorrectorSelector from a dictionary containing all the information
+ needed to build a CorrectorConfigProtocol.
+ """
+ return dacite.from_dict(
+ data_class=cls, data=state, config=dacite.Config(strict=True)
+ )
+
+ @classmethod
+ def from_dict(cls, config: dict):
+ instance = cls.registry.from_dict(config)
+ return cls(config=instance, type=config["type"])
+
+ @classmethod
+ def get_available_types(cls):
+ """This class method is used to expose all available types of Correctors."""
+ return cls(type="", config={}).registry._types.keys()
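Note: a hedged usage sketch of the new CorrectorSelector; the "my_corrector" name and MyCorrectorConfig class are hypothetical and not part of this PR:

import dataclasses

from fme.core.registry import CorrectorSelector

@CorrectorSelector.register("my_corrector")
@dataclasses.dataclass
class MyCorrectorConfig:
    strength: float = 1.0

    def build(self, gridded_operations, vertical_coordinate, timestep):
        ...  # would return the corrector object used by the stepper

selector = CorrectorSelector(type="my_corrector", config={"strength": 0.5})
# selector.build(gridded_operations, vertical_coordinate, timestep) then builds
# MyCorrectorConfig(strength=0.5) via the registry and delegates to its build().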
diff --git a/fme/fme/core/registry.py b/fme/fme/core/registry/module.py
similarity index 66%
rename from fme/fme/core/registry.py
rename to fme/fme/core/registry/module.py
index ffe6ec7..29afade 100644
--- a/fme/fme/core/registry.py
+++ b/fme/fme/core/registry/module.py
@@ -1,10 +1,12 @@
import abc
import dataclasses
-from typing import Any, Callable, Dict, Mapping, Tuple, Type
+from typing import Any, Callable, ClassVar, Mapping, Tuple, Type, TypeVar, Union
import dacite
from torch import nn
+from .registry import Registry
+
@dataclasses.dataclass
class ModuleConfig(abc.ABC):
@@ -30,8 +32,8 @@ def build(
Args:
n_in_channels: number of input channels
n_out_channels: number of output channels
- img_shape: last two dimensions of data, corresponding to lat and
- lon when using FourCastNet conventions
+ img_shape: shape of last two dimensions of data, e.g. latitude and
+ longitude.
Returns:
a nn.Module
@@ -49,41 +51,7 @@ def from_state(cls, state: Mapping[str, Any]) -> "ModuleConfig":
)
-NET_REGISTRY: Dict[str, Type[ModuleConfig]] = {}
-
-
-def get_available_module_types():
- return NET_REGISTRY.keys()
-
-
-def register(name: str) -> Callable[[Type[ModuleConfig]], Type[ModuleConfig]]:
- """
- Register a new ModuleConfig type with the NET_REGISTRY.
-
- This is useful for adding new ModuleConfig types to the registry from
- other modules.
-
- Args:
- name: name of the ModuleConfig type to register
-
- Returns:
- a decorator which registers the decorated class with the NET_REGISTRY
- """
- if not isinstance(name, str):
- raise TypeError(
- f"name must be a string, got {name}, "
- "make sure to use as @register('module_name')"
- )
-
- def decorator(cls: Type[ModuleConfig]) -> Type[ModuleConfig]:
- NET_REGISTRY[name] = cls
- return cls
-
- return decorator
-
-
-def get_from_registry(name) -> Type[ModuleConfig]:
- return NET_REGISTRY[name]
+MT = TypeVar("MT", bound=Type[ModuleConfig])
@dataclasses.dataclass
@@ -100,22 +68,22 @@ class ModuleSelector:
used to represent any ModuleConfig, whereas ModuleConfig is a protocol
that can be implemented by many different classes.
- Attributes:
+ Parameters:
type: the type of the ModuleConfig
config: data for a ModuleConfig instance of the indicated type
"""
type: str
config: Mapping[str, Any]
+ registry: ClassVar[Registry] = Registry()
+
+ def __post_init__(self):
+ if self.registry is not type(self).registry:
+ raise ValueError("ModuleSelector.registry should not be set manually")
- def __post_init__(self):
- try:
- self._config = get_from_registry(self.type).from_state(self.config)
- except KeyError:
- raise ValueError(
- f"unknown module type {self.type}, "
- f"known module types are {list(NET_REGISTRY.keys())}"
- )
+ @classmethod
+ def register(cls, type_name) -> Callable[[MT], MT]:
+ return cls.registry.register(type_name)
def build(
self,
@@ -136,13 +104,14 @@ def build(
Returns:
a nn.Module
"""
- return self._config.build(
+ instance = self.registry.from_dict(self.get_state())
+ return instance.build(
n_in_channels=n_in_channels,
n_out_channels=n_out_channels,
img_shape=img_shape,
)
- def get_state(self) -> Mapping[str, Any]:
+ def get_state(self) -> Union[Mapping[str, Any], dict]:
"""
Get a dictionary containing all the information needed to build a ModuleConfig.
"""
@@ -155,5 +124,15 @@ def from_state(cls, state: Mapping[str, Any]) -> "ModuleSelector":
needed to build a ModuleConfig.
"""
return dacite.from_dict(
- data_class=ModuleSelector, data=state, config=dacite.Config(strict=True)
+ data_class=cls, data=state, config=dacite.Config(strict=True)
)
+
+ @classmethod
+ def from_dict(cls, config: dict):
+ instance = cls.registry.from_dict(config)
+ return cls(config=instance, type=config["type"])
+
+ @classmethod
+ def get_available_types(cls):
+ """This class method is used to expose all available types of Modules."""
+ return cls(type="", config={}).registry._types.keys()
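Note: a hedged sketch of the serialization round trip this refactor preserves, using the "mock" type registered in the test module below:

from fme.core.registry import ModuleSelector

selector = ModuleSelector(type="mock", config={"param_shapes": [(1, 2, 3)]})
state = selector.get_state()                # {"type": "mock", "config": {...}}
restored = ModuleSelector.from_state(state)
module = restored.build(n_in_channels=1, n_out_channels=1, img_shape=(16, 32))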
diff --git a/fme/fme/core/registry/registry.py b/fme/fme/core/registry/registry.py
new file mode 100644
index 0000000..0c8d5a7
--- /dev/null
+++ b/fme/fme/core/registry/registry.py
@@ -0,0 +1,73 @@
+from typing import Any, Callable, Dict, Generic, Mapping, Optional, Type, TypeVar
+
+import dacite
+
+T = TypeVar("T")
+TT = TypeVar("TT", bound=Type)
+
+
+class Registry(Generic[T]):
+ """
+ Used to register and initialize multiple types of a dataclass.
+ """
+
+ def __init__(self, default_type: Optional[str] = None):
+ """
+ Initialize the registry.
+
+ Args:
+ default_type: if given, the "type" key in the config dict is optional
+ and by default this type will be used.
+ """
+ self._types: Dict[str, Type[T]] = {}
+ self.default_type = default_type
+
+ def register(self, type_name: str) -> Callable[[TT], TT]:
+ """
+ Registers a configuration type with the registry.
+
+ When registry.from_dict is called to initialize a dataclass, if the
+ "type" key in that dictionary is equal to the type_name you give here,
+ then the decorated class will be the one initialized from the data
+ in the "config" key.
+
+ Args:
+ type_name: name used in configuration to indicate the decorated
+ class as the target type to be initialized when using from_dict.
+ """
+
+ def register_func(cls: TT) -> TT:
+ self._types[type_name] = cls
+ return cls
+
+ return register_func
+
+ def from_dict(self, config: Mapping[str, Any]) -> T:
+ """
+ Creates a registered type from the given config dict.
+
+ Config should have at least one key, "type", which indicates the type to
+ initialize based on its registered type name. This can be omitted if
+ this instance was initialized with a default type.
+
+ It can also have a "config" key, which is a dict used to initialize the
+ dataclass. By default this is an empty dict.
+ """
+ config = dict(config)
+ config.setdefault("config", {})
+ if self.default_type is not None:
+ type_name = config.get("type", self.default_type)
+ else:
+ type_name = config["type"]
+ if type_name not in self._types:
+ raise ValueError(
+ f"Received unexpected type {type_name}, "
+ f"expected one of {self._types.keys()}"
+ )
+ else:
+ instance = dacite.from_dict(
+ data_class=self._types[type_name],
+ data=config["config"],
+ config=dacite.Config(strict=True),
+ )
+ return instance
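Note: a hedged usage sketch of the new Registry; ExampleConfig and the "example" type name are illustrative only:

import dataclasses

from fme.core.registry.registry import Registry

registry: Registry = Registry(default_type="example")

@registry.register("example")
@dataclasses.dataclass
class ExampleConfig:
    learning_rate: float = 1e-3

# "type" may be omitted because default_type was given; "config" defaults to {}.
instance = registry.from_dict({"config": {"learning_rate": 1e-4}})
assert isinstance(instance, ExampleConfig)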
diff --git a/fme/fme/core/registry/test_module_registry.py b/fme/fme/core/registry/test_module_registry.py
new file mode 100644
index 0000000..6dba17d
--- /dev/null
+++ b/fme/fme/core/registry/test_module_registry.py
@@ -0,0 +1,38 @@
+import dataclasses
+from typing import Iterable, List, Tuple
+
+import torch
+
+from .module import ModuleConfig, ModuleSelector
+
+
+class MockModule(torch.nn.Module):
+ def __init__(self, param_shapes: Iterable[Tuple[int, ...]]):
+ super().__init__()
+ for i, shape in enumerate(param_shapes):
+ setattr(self, f"param{i}", torch.nn.Parameter(torch.randn(shape)))
+
+
+@ModuleSelector.register("mock")
+@dataclasses.dataclass
+class MockModuleBuilder(ModuleConfig):
+ param_shapes: List[Tuple[int, ...]]
+
+ def build(self, n_in_channels, n_out_channels, img_shape):
+ return MockModule(self.param_shapes)
+
+ @classmethod
+ def from_state(cls, state):
+ return cls(state["param_shapes"])
+
+ def get_state(self):
+ return {
+ "param_shapes": self.param_shapes,
+ }
+
+
+def test_register():
+ """Make sure that the registry is working as expected."""
+ selector = ModuleSelector(type="mock", config={"param_shapes": [(1, 2, 3)]})
+ module = selector.build(1, 1, (16, 32))
+ assert isinstance(module, MockModule)
diff --git a/fme/fme/core/regrid.py b/fme/fme/core/regrid.py
new file mode 100644
index 0000000..e69de29
diff --git a/fme/fme/core/scheduler.py b/fme/fme/core/scheduler.py
index 8227af3..cc771ba 100644
--- a/fme/fme/core/scheduler.py
+++ b/fme/fme/core/scheduler.py
@@ -9,7 +9,7 @@ class SchedulerConfig:
"""
Configuration for a scheduler to use during training.
- Attributes:
+ Parameters:
type: Name of scheduler class from torch.optim.lr_scheduler,
no scheduler is used by default.
kwargs: Keyword arguments to pass to the scheduler constructor.
diff --git a/fme/fme/core/stacker.py b/fme/fme/core/stacker.py
new file mode 100644
index 0000000..707c662
--- /dev/null
+++ b/fme/fme/core/stacker.py
@@ -0,0 +1,151 @@
+import re
+from typing import List, Mapping, Union
+
+import torch
+
+from fme.core.typing_ import TensorMapping
+
+
+def natural_sort(alist: List[str]) -> List[str]:
+ """Sort to alphabetical order but with numbers sorted
+ numerically, e.g. a11 comes after a2. See [1] and [2].
+
+ [1] https://stackoverflow.com/questions/11150239/natural-sorting
+ [2] https://en.wikipedia.org/wiki/Natural_sort_order
+ """
+
+ def convert(text: str) -> Union[str, int]:
+ if text.isdigit():
+ return int(text)
+ else:
+ return text.lower()
+
+ def alphanum_key(item: str) -> List[Union[str, int]]:
+ return [convert(c) for c in re.split("([0-9]+)", item)]
+
+ return sorted(alist, key=alphanum_key)
+
+
+def unstack(tensor: torch.Tensor, names: List[str], dim: int = -1) -> TensorMapping:
+ """Unstack a 3D variable to a dictionary of 2D variables.
+
+ Args:
+ tensor: 3D tensor to unstack, such as output by a Stacker.
+ names: List of names in natural order to assign to the unstacked variables.
+ Stacker.get_all_level_names can help in retrieving these names when the
+ input tensor was stacked via a Stacker.
+ dim: Dimension along which to unstack.
+
+ """
+ if len(names) != tensor.size(dim):
+ raise ValueError(
+ f"Received {len(names)} names, but 3D tensor has "
+ f"{tensor.size(-1)} levels."
+ )
+ if len(names) == 1:
+ return {names[0]: tensor.select(dim=dim, index=0)}
+ # split the output tensor along the vertical dimension
+ tensors = torch.split(tensor, 1, dim=dim)
+ return {name: tensor.squeeze(dim=dim) for name, tensor in zip(names, tensors)}
+
+
+class Stacker:
+ """Handles extraction and stacking of data tensors for 3D variables."""
+
+ LEVEL_PATTERN = re.compile(r"_(\d+)$")
+
+ def __init__(
+ self,
+ prefix_map: Mapping[str, List[str]],
+ ):
+ """
+ Args:
+ prefix_map: Mapping which defines the correspondence between an arbitrary
+ set of "standard" names (e.g., "surface_pressure" or "air_temperature")
+ and lists of possible names or prefix variants (e.g., ["PRESsfc", "PS"]
+ or ["air_temperature_", "T_"]) found in the data.
+ """
+ self._prefix_map = prefix_map
+
+ @property
+ def prefix_map(self) -> Mapping[str, List[str]]:
+ """Mapping which defines the correspondence between an arbitrary set of
+ "standard" names (e.g., "surface_pressure" or "air_temperature") and
+ lists of possible names or prefix variants (e.g., ["PRESsfc", "PS"] or
+ ["air_temperature_", "T_"]) found in the data.
+ """
+ return self._prefix_map
+
+ @property
+ def standard_names(self) -> List[str]:
+ return list(self._prefix_map.keys())
+
+ def get_all_level_names(self, standard_name: str, data: TensorMapping) -> List[str]:
+ """Get the names of all variables in the data that match one of the
+ prefixes associated with the given standard name. If the standard name
+ corresponds to a 3D variable, returns all vertical level names in their
+ natural order.
+ """
+ if standard_name not in self.standard_names:
+ raise ValueError(f"{standard_name} is not a standard name.")
+ for prefix_or_name in self._prefix_map[standard_name]:
+ if prefix_or_name in data:
+ return [prefix_or_name]
+ try:
+ return self._natural_sort_names(prefix_or_name, data)
+ except KeyError:
+ pass
+ raise KeyError(
+ f"No prefix associated with {standard_name} was found in data keys."
+ )
+
+ def __call__(self, standard_name: str, data: TensorMapping) -> torch.Tensor:
+ """Extract the variable corresponding to standard name and return as a
+ 3D tensor.
+
+ """
+ return self._stack_levels_try(standard_name, data)
+
+ def _stack_levels(self, prefix_or_name: str, data: TensorMapping) -> torch.Tensor:
+ names = self._natural_sort_names(prefix_or_name, data)
+ # stack along the final dimension
+ return torch.stack([data[name] for name in names], dim=-1)
+
+ def _stack_levels_try(
+ self, standard_name: str, data: TensorMapping
+ ) -> torch.Tensor:
+ prefixes_or_names = self._prefix_map[standard_name]
+ for prefix_or_name in prefixes_or_names:
+ if prefix_or_name in data:
+ # 2D variable, return as 1-level 3D tensor
+ return data[prefix_or_name].unsqueeze(-1)
+ try:
+ return self._stack_levels(prefix_or_name, data)
+ except KeyError:
+ pass
+ raise KeyError(
+ f"Found no matches for any of {prefixes_or_names} "
+ f"among the data names {list(data.keys())}."
+ )
+
+ def _natural_sort_names(self, prefix: str, data: TensorMapping) -> List[str]:
+ names = [field_name for field_name in data if field_name.startswith(prefix)]
+
+ levels = []
+ for name in names:
+ match = self.LEVEL_PATTERN.search(name)
+ if match is None:
+ raise ValueError(
+ f"Invalid field name {name}, is a prefix variable "
+ "but does not end in _{number}."
+ )
+ levels.append(int(match.group(1)))
+
+ for i, level in enumerate(sorted(levels)):
+ if i != level:
+ raise ValueError(f"Missing level {i} in {prefix} levels {levels}.")
+
+ if len(names) == 0:
+ raise KeyError(prefix)
+
+ return natural_sort(names)
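Note: a hedged usage sketch of Stacker and unstack with made-up field names:

import torch

from fme.core.stacker import Stacker, unstack

data = {
    "PRESsfc": torch.rand(2, 8, 16),
    "air_temperature_0": torch.rand(2, 8, 16),
    "air_temperature_1": torch.rand(2, 8, 16),
}
stacker = Stacker(
    {
        "surface_pressure": ["PRESsfc", "PS"],
        "air_temperature": ["air_temperature_", "T_"],
    }
)
temperature = stacker("air_temperature", data)   # stacked along a new last dim, shape [2, 8, 16, 2]
names = stacker.get_all_level_names("air_temperature", data)
fields = unstack(temperature, names)             # back to per-level 2D fields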
diff --git a/fme/fme/core/stepper.py b/fme/fme/core/stepper.py
deleted file mode 100644
index 9cea13b..0000000
--- a/fme/fme/core/stepper.py
+++ /dev/null
@@ -1,677 +0,0 @@
-import dataclasses
-import datetime
-from copy import copy
-from typing import Any, Dict, List, Mapping, Optional, Tuple, Union
-
-import dacite
-import torch
-from torch import nn
-
-from fme.core.corrector import CorrectorConfig
-from fme.core.data_loading.data_typing import SigmaCoordinates
-from fme.core.data_loading.requirements import DataRequirements
-from fme.core.data_loading.utils import decode_timestep, encode_timestep
-from fme.core.device import get_device
-from fme.core.distributed import Distributed
-from fme.core.loss import WeightedMappingLossConfig
-from fme.core.normalizer import (
- FromStateNormalizer,
- NormalizationConfig,
- StandardNormalizer,
-)
-from fme.core.ocean import Ocean, OceanConfig
-from fme.core.packer import Packer
-from fme.core.registry import ModuleSelector
-
-from .optimization import NullOptimization, Optimization
-from .parameter_init import ParameterInitializationConfig
-from .typing_ import TensorDict, TensorMapping
-
-DEFAULT_TIMESTEP = datetime.timedelta(hours=6)
-DEFAULT_ENCODED_TIMESTEP = encode_timestep(DEFAULT_TIMESTEP)
-
-
-@dataclasses.dataclass
-class SingleModuleStepperConfig:
- """
- Configuration for a single module stepper.
-
- Attributes:
- builder: The module builder.
- in_names: Names of input variables.
- out_names: Names of output variables.
- normalization: The normalization configuration.
- parameter_init: The parameter initialization configuration.
- ocean: The ocean configuration.
- loss: The loss configuration.
- corrector: The corrector configuration.
- next_step_forcing_names: Names of forcing variables for the next timestep.
- loss_normalization: The normalization configuration for the loss.
- residual_normalization: Optional alternative to configure loss normalization.
- If provided, it will be used for all *prognostic* variables in loss scaling.
- """
-
- builder: ModuleSelector
- in_names: List[str]
- out_names: List[str]
- normalization: Union[NormalizationConfig, FromStateNormalizer]
- parameter_init: ParameterInitializationConfig = dataclasses.field(
- default_factory=lambda: ParameterInitializationConfig()
- )
- ocean: Optional[OceanConfig] = None
- loss: WeightedMappingLossConfig = dataclasses.field(
- default_factory=lambda: WeightedMappingLossConfig()
- )
- corrector: CorrectorConfig = dataclasses.field(
- default_factory=lambda: CorrectorConfig()
- )
- next_step_forcing_names: List[str] = dataclasses.field(default_factory=list)
- loss_normalization: Optional[Union[NormalizationConfig, FromStateNormalizer]] = None
- residual_normalization: Optional[
- Union[NormalizationConfig, FromStateNormalizer]
- ] = None
-
- def __post_init__(self):
- for name in self.next_step_forcing_names:
- if name not in self.in_names:
- raise ValueError(
- f"next_step_forcing_name '{name}' not in in_names: {self.in_names}"
- )
- if name in self.out_names:
- raise ValueError(
- f"next_step_forcing_name is an output variable: '{name}'"
- )
- if (
- self.residual_normalization is not None
- and self.loss_normalization is not None
- ):
- raise ValueError(
- "Only one of residual_normalization, loss_normalization can "
- "be provided."
- "If residual_normalization is provided, it will be used for all "
- "*prognostic* variables in loss scalng. "
- "If loss_normalization is provided, it will be used for all variables "
- "in loss scaling."
- )
-
- def get_data_requirements(self, n_forward_steps: int) -> DataRequirements:
- return DataRequirements(
- names=self.all_names,
- n_timesteps=n_forward_steps + 1,
- )
-
- def get_forcing_data_requirements(self, n_forward_steps: int) -> DataRequirements:
- if self.ocean is None:
- names = self.forcing_names
- else:
- names = list(set(self.forcing_names).union(self.ocean.forcing_names))
-
- return DataRequirements(names=names, n_timesteps=n_forward_steps + 1)
-
- def get_state(self):
- return dataclasses.asdict(self)
-
- def get_base_weights(self) -> Optional[List[Mapping[str, Any]]]:
- """
- If the model is being initialized from another model's weights for fine-tuning,
- returns those weights. Otherwise, returns None.
-
- The list mirrors the order of `modules` in the `SingleModuleStepper` class.
- """
- base_weights = self.parameter_init.get_base_weights()
- if base_weights is not None:
- return [base_weights]
- else:
- return None
-
- def get_stepper(
- self,
- img_shape: Tuple[int, int],
- area: Optional[torch.Tensor],
- sigma_coordinates: SigmaCoordinates,
- timestep: datetime.timedelta,
- ):
- return SingleModuleStepper(
- config=self,
- img_shape=img_shape,
- area=area,
- sigma_coordinates=sigma_coordinates,
- timestep=timestep,
- )
-
- @classmethod
- def from_state(cls, state) -> "SingleModuleStepperConfig":
- state = cls.remove_deprecated_keys(state)
- return dacite.from_dict(
- data_class=cls, data=state, config=dacite.Config(strict=True)
- )
-
- @property
- def all_names(self):
- """Names of all variables required, including auxiliary ones."""
- extra_names = []
- if self.ocean is not None:
- extra_names.extend(self.ocean.forcing_names)
- all_names = list(set(self.in_names).union(self.out_names).union(extra_names))
- return all_names
-
- @property
- def normalize_names(self):
- """Names of variables which require normalization. I.e. inputs/outputs."""
- return list(set(self.in_names).union(self.out_names))
-
- @property
- def forcing_names(self) -> List[str]:
- """Names of variables which are inputs only."""
- return list(set(self.in_names) - set(self.out_names))
-
- @property
- def prognostic_names(self) -> List[str]:
- """Names of variables which both inputs and outputs."""
- return list(set(self.out_names).intersection(self.in_names))
-
- @classmethod
- def remove_deprecated_keys(cls, state: Dict[str, Any]) -> Dict[str, Any]:
- _unsupported_key_defaults = {
- "conserve_dry_air": False,
- "optimization": None,
- "conservation_loss": {"dry_air_penalty": None},
- }
- state_copy = state.copy()
- for key, default in _unsupported_key_defaults.items():
- if key in state_copy:
- if state_copy[key] == default or state_copy[key] is None:
- del state_copy[key]
- else:
- raise ValueError(
- f"The stepper config option {key} is deprecated and the setting"
- f" provided, {state_copy[key]}, is no longer implemented. The "
- "SingleModuleStepper being loaded from state cannot be run by "
- "this version of the code."
- )
- if "prescriber" in state_copy:
- # want to maintain backwards compatibility for this particular feature
- if state_copy["prescriber"] is not None:
- if state_copy.get("ocean") is not None:
- raise ValueError("Cannot specify both prescriber and ocean.")
- state_copy["ocean"] = {
- "surface_temperature_name": state_copy["prescriber"][
- "prescribed_name"
- ],
- "ocean_fraction_name": state_copy["prescriber"]["mask_name"],
- "interpolate": state_copy["prescriber"]["interpolate"],
- }
- del state_copy["prescriber"]
- return state_copy
-
-
-@dataclasses.dataclass
-class ExistingStepperConfig:
- """
- Configuration for an existing stepper. This is only designed to point to
- a serialized stepper checkpoint for loading, e.g., in the case of training
- resumption.
-
- Attributes:
- checkpoint_path: The path to the serialized checkpoint.
- """
-
- checkpoint_path: str
-
- def _load_checkpoint(self) -> Mapping[str, Any]:
- return torch.load(self.checkpoint_path, map_location=get_device())
-
- def get_data_requirements(self, n_forward_steps: int) -> DataRequirements:
- return SingleModuleStepperConfig.from_state(
- self._load_checkpoint()["stepper"]["config"]
- ).get_data_requirements(n_forward_steps)
-
- def get_base_weights(self) -> Optional[List[Mapping[str, Any]]]:
- return SingleModuleStepperConfig.from_state(
- self._load_checkpoint()["stepper"]["config"]
- ).get_base_weights()
-
- def get_stepper(self, img_shape, area, sigma_coordinates, timestep):
- del img_shape # unused
- return SingleModuleStepper.from_state(
- self._load_checkpoint()["stepper"],
- area=area,
- sigma_coordinates=sigma_coordinates,
- )
-
-
-def _combine_normalizers(
- residual_normalizer: StandardNormalizer,
- model_normalizer: StandardNormalizer,
-) -> StandardNormalizer:
- # Combine residual and model normalizers by overwriting the model normalizer
- # values that are present in residual normalizer. The residual normalizer
- # is assumed to have a subset of prognostic keys only.
- means, stds = copy(model_normalizer.means), copy(model_normalizer.stds)
- means.update(residual_normalizer.means)
- stds.update(residual_normalizer.stds)
- return StandardNormalizer(means=means, stds=stds)
-
-
-def _cast_tensordict(
- data: TensorDict, dtype: Optional[torch.dtype] = None
-) -> TensorDict:
- device = get_device()
- return {name: value.to(device, dtype=dtype) for name, value in data.items()}
-
-
-def _prepend_timestep(
- data: TensorDict, timestep: TensorDict, time_dim: int = 1
-) -> TensorDict:
- return {
- k: torch.cat([timestep[k].unsqueeze(time_dim), v], dim=time_dim)
- for k, v in data.items()
- }
-
-
-@dataclasses.dataclass
-class SteppedData:
- metrics: TensorDict
- gen_data: TensorDict
- target_data: TensorDict
- gen_data_norm: TensorDict
- target_data_norm: TensorDict
-
- def remove_initial_condition(self) -> "SteppedData":
- return SteppedData(
- metrics=self.metrics,
- gen_data={k: v[:, 1:] for k, v in self.gen_data.items()},
- target_data={k: v[:, 1:] for k, v in self.target_data.items()},
- gen_data_norm={k: v[:, 1:] for k, v in self.gen_data_norm.items()},
- target_data_norm={k: v[:, 1:] for k, v in self.target_data_norm.items()},
- )
-
- def copy(self) -> "SteppedData":
- """Creates new dictionaries for the data but with the same tensors."""
- return SteppedData(
- metrics=self.metrics,
- gen_data={k: v for k, v in self.gen_data.items()},
- target_data={k: v for k, v in self.target_data.items()},
- gen_data_norm={k: v for k, v in self.gen_data_norm.items()},
- target_data_norm={k: v for k, v in self.target_data_norm.items()},
- )
-
- def prepend_initial_condition(
- self,
- initial_condition: TensorDict,
- normalized_initial_condition: TensorDict,
- ) -> "SteppedData":
- """
- Prepends an initial condition to the existing stepped data.
- """
- initial_condition = _cast_tensordict(initial_condition)
- normalized_initial_condition = _cast_tensordict(normalized_initial_condition)
-
- return SteppedData(
- metrics=self.metrics,
- gen_data=_prepend_timestep(self.gen_data, initial_condition),
- target_data=_prepend_timestep(self.target_data, initial_condition),
- gen_data_norm=_prepend_timestep(
- self.gen_data_norm, normalized_initial_condition
- ),
- target_data_norm=_prepend_timestep(
- self.target_data_norm, normalized_initial_condition
- ),
- )
-
-
-class SingleModuleStepper:
- """
- Stepper class for a single pytorch module.
- """
-
- TIME_DIM = 1
- CHANNEL_DIM = -3
-
- def __init__(
- self,
- config: SingleModuleStepperConfig,
- img_shape: Tuple[int, int],
- area: torch.Tensor,
- sigma_coordinates: SigmaCoordinates,
- timestep: datetime.timedelta,
- init_weights: bool = True,
- ):
- """
- Args:
- config: The configuration.
- img_shape: Shape of domain as (n_lat, n_lon).
- area: (n_lat, n_lon) array containing relative gridcell area,
- in any units including unitless.
- sigma_coordinates: The sigma coordinates.
- timestep: Timestep of the model.
- init_weights: Whether to initialize the weights. Should pass False if
- the weights are about to be overwritten by a checkpoint.
- """
- n_in_channels = len(config.in_names)
- n_out_channels = len(config.out_names)
- self.in_packer = Packer(config.in_names)
- self.out_packer = Packer(config.out_names)
- self.normalizer = config.normalization.build(config.normalize_names)
- if config.ocean is not None:
- self.ocean: Optional[Ocean] = config.ocean.build(
- config.in_names, config.out_names, timestep
- )
- else:
- self.ocean = None
- self.module = config.builder.build(
- n_in_channels=n_in_channels,
- n_out_channels=n_out_channels,
- img_shape=img_shape,
- )
- module, self._l2_sp_tuning_regularizer = config.parameter_init.apply(
- self.module, init_weights=init_weights
- )
- self.module = module.to(get_device())
-
- self._img_shape = img_shape
- self._config = config
- self._no_optimization = NullOptimization()
-
- dist = Distributed.get_instance()
- self._is_distributed = dist.is_distributed()
- self.module = dist.wrap_module(self.module)
-
- self.area = area.to(get_device())
- self.sigma_coordinates = sigma_coordinates.to(get_device())
- self.timestep = timestep
-
- self.loss_obj = config.loss.build(self.area, config.out_names, self.CHANNEL_DIM)
-
- self._corrector = config.corrector.build(
- area=self.area,
- sigma_coordinates=self.sigma_coordinates,
- timestep=timestep,
- )
- if config.loss_normalization is not None:
- self.loss_normalizer = config.loss_normalization.build(
- names=config.normalize_names
- )
- elif config.residual_normalization is not None:
- # Use residual norm for prognostic variables and input/output
- # normalizer for diagnostic variables in loss
- self.loss_normalizer = _combine_normalizers(
- residual_normalizer=config.residual_normalization.build(
- config.prognostic_names
- ),
- model_normalizer=self.normalizer,
- )
- else:
- self.loss_normalizer = self.normalizer
-
- def get_data_requirements(self, n_forward_steps: int) -> DataRequirements:
- return self._config.get_data_requirements(n_forward_steps)
-
- @property
- def effective_loss_scaling(self) -> TensorMapping:
- """
- Effective loss scalings used to normalize outputs before computing loss.
- y_loss_normalized_i = (y_i - y_mean_i) / loss_scaling_i
- where loss_scaling_i = loss_normalizer_std_i / weight_i
- """
- custom_weights = self._config.loss.weights
- loss_normalizer_stds = self.loss_normalizer.stds
- return {
- k: loss_normalizer_stds[k] / custom_weights.get(k, 1.0)
- for k in self._config.out_names
- }
-
- @property
- def prognostic_names(self) -> List[str]:
- return sorted(
- list(set(self.out_packer.names).intersection(self.in_packer.names))
- )
-
- @property
- def modules(self) -> nn.ModuleList:
- """
- Returns:
- A list of modules being trained.
- """
- return nn.ModuleList([self.module])
-
- def step(
- self,
- input: TensorMapping,
- ocean_data: TensorMapping,
- ) -> TensorDict:
- """
- Step the model forward one timestep given input data.
-
- Args:
- input: Mapping from variable name to tensor of shape
- [n_batch, n_lat, n_lon]. This data is used as input for `self.module`
- and is assumed to contain all input variables and be denormalized.
- ocean_data: Mapping from variable name to tensor of shape
- [n_batch, n_lat, n_lon]. This must contain the necessary data at the
- output timestep for the ocean model (e.g. surface temperature,
- mixed-layer depth etc.).
-
- Returns:
- The denormalized output data at the next time step.
- """
- input_norm = self.normalizer.normalize(input)
- input_tensor = self.in_packer.pack(input_norm, axis=self.CHANNEL_DIM)
- output_tensor = self.module(input_tensor)
- output_norm = self.out_packer.unpack(output_tensor, axis=self.CHANNEL_DIM)
- output = self.normalizer.denormalize(output_norm)
- if self._corrector is not None:
- output = self._corrector(input, output)
- if self.ocean is not None:
- output = self.ocean(ocean_data, input, output)
- return output
-
- def predict(
- self,
- initial_condition: TensorMapping,
- forcing_data: TensorMapping,
- n_forward_steps: int,
- ) -> TensorDict:
- """
- Predict multiple steps forward given initial condition and forcing data.
-
- Args:
- initial_condition: Mapping from variable name to tensors of shape
- [n_batch, n_lat, n_lon]. This data is assumed to contain all prognostic
- variables and be denormalized.
- forcing_data: Mapping from variable name to tensors of shape
- [n_batch, n_forward_steps + 1, n_lat, n_lon]. This contains the forcing
- and ocean data for the initial condition and all subsequent timesteps.
- n_forward_steps: The number of timesteps to run the model forward for.
-
- Returns:
- The denormalized output data for all the forward timesteps. Shape of
- each tensor will be [n_batch, n_forward_steps, n_lat, n_lon].
- """
- output_list = []
- state = initial_condition
- forcing_names = self._config.forcing_names
- ocean_forcing_names = self.ocean.forcing_names if self.ocean is not None else []
- for step in range(n_forward_steps):
- current_step_forcing = {
- k: (
- forcing_data[k][:, step]
- if k not in self._config.next_step_forcing_names
- else forcing_data[k][:, step + 1]
- )
- for k in forcing_names
- }
- next_step_ocean_data = {
- k: forcing_data[k][:, step + 1] for k in ocean_forcing_names
- }
- input_data = {**state, **current_step_forcing}
- state = self.step(input_data, next_step_ocean_data)
- output_list.append(state)
- output_timeseries = {}
- for name in state:
- output_timeseries[name] = torch.stack(
- [x[name] for x in output_list], dim=self.TIME_DIM
- )
- return output_timeseries
-
- def get_initial_condition(self, data: TensorDict) -> Tuple[TensorDict, TensorDict]:
- ic = {k: v.select(self.TIME_DIM, 0) for k, v in data.items()}
- return ic, self.normalizer.normalize(ic)
-
- def run_on_batch(
- self,
- data: TensorDict,
- optimization: Union[Optimization, NullOptimization],
- n_forward_steps: int = 1,
- ) -> SteppedData:
- """
- Step the model forward multiple steps on a batch of data.
-
- Args:
- data: The batch data where each tensor has shape
- [n_sample, n_forward_steps + 1, n_lat, n_lon].
- optimization: The optimization class to use for updating the module.
- Use `NullOptimization` to disable training.
- n_forward_steps: The number of timesteps to run the model for.
- aggregator: The data aggregator.
-
- Returns:
- The loss metrics, the generated data, the normalized generated data,
- and the normalized batch data.
- """
- data = _cast_tensordict(data, dtype=torch.float)
- time_dim = self.TIME_DIM
- if self.ocean is None:
- forcing_names = self._config.forcing_names
- else:
- forcing_names = self._config.forcing_names + self.ocean.forcing_names
- forcing_data = {k: data[k] for k in forcing_names}
-
- loss = torch.tensor(0.0, device=get_device())
- metrics = {}
-
- input_data = {
- k: data[k].select(time_dim, 0) for k in self._config.prognostic_names
- }
- # Remove the initial condition from target data
- data = {k: data[k][:, 1:] for k in data}
-
- optimization.set_mode(self.module)
- with optimization.autocast():
- # output from self.predict does not include initial condition
- gen_data = self.predict(input_data, forcing_data, n_forward_steps)
-
- # compute loss for each timestep
- for step in range(n_forward_steps):
- gen_step = {k: v.select(time_dim, step) for k, v in gen_data.items()}
- target_step = {k: v.select(time_dim, step) for k, v in data.items()}
- gen_norm_step = self.loss_normalizer.normalize(gen_step)
- target_norm_step = self.loss_normalizer.normalize(target_step)
-
- step_loss = self.loss_obj(gen_norm_step, target_norm_step)
- loss += step_loss
- metrics[f"loss_step_{step}"] = step_loss.detach()
-
- loss += self._l2_sp_tuning_regularizer()
-
- metrics["loss"] = loss.detach()
- optimization.step_weights(loss)
-
- gen_data_norm = self.normalizer.normalize(gen_data)
- full_data_norm = self.normalizer.normalize(data)
-
- return SteppedData(
- metrics=metrics,
- gen_data=gen_data,
- target_data=data,
- gen_data_norm=gen_data_norm,
- target_data_norm=full_data_norm,
- )
-
- def get_state(self):
- """
- Returns:
- The state of the stepper.
- """
- return {
- "module": self.module.state_dict(),
- "normalizer": self.normalizer.get_state(),
- "img_shape": self._img_shape,
- "config": self._config.get_state(),
- "area": self.area,
- "sigma_coordinates": self.sigma_coordinates.as_dict(),
- "encoded_timestep": encode_timestep(self.timestep),
- "loss_normalizer": self.loss_normalizer.get_state(),
- }
-
- def load_state(self, state):
- """
- Load the state of the stepper.
-
- Args:
- state: The state to load.
- """
- if "module" in state:
- module = state["module"]
- if "module.device_buffer" in module:
- # for backwards compatibility with old checkpoints
- del module["module.device_buffer"]
- self.module.load_state_dict(module)
-
- @classmethod
- def from_state(
- cls,
- state,
- area: torch.Tensor,
- sigma_coordinates: SigmaCoordinates,
- ) -> "SingleModuleStepper":
- """
- Load the state of the stepper.
-
- Args:
- state: The state to load.
- area: (n_lat, n_lon) array containing relative gridcell area, in any
- units including unitless.
- sigma_coordinates: The sigma coordinates.
-
- Returns:
- The stepper.
- """
- config = {**state["config"]} # make a copy to avoid mutating input
- config["normalization"] = FromStateNormalizer(state["normalizer"])
-
- # for backwards compatibility with previous steppers created w/o
- # loss_normalization or residual_normalization
- loss_normalizer_state = state.get("loss_normalizer", state["normalizer"])
- config["loss_normalization"] = FromStateNormalizer(loss_normalizer_state)
- # Overwrite the residual_normalization key if it exists, since the combined
- # loss scalings are saved in initial training as the loss_normalization
- config["residual_normalization"] = None
-
- area = state.get("area", area)
- if "sigma_coordinates" in state:
- sigma_coordinates = dacite.from_dict(
- data_class=SigmaCoordinates,
- data=state["sigma_coordinates"],
- config=dacite.Config(strict=True),
- )
- encoded_timestep = state.get("encoded_timestep", DEFAULT_ENCODED_TIMESTEP)
- timestep = decode_timestep(encoded_timestep)
- if "img_shape" in state:
- img_shape = state["img_shape"]
- else:
- # this is for backwards compatibility with old checkpoints
- for v in state["data_shapes"].values():
- img_shape = v[-2:]
- break
- stepper = cls(
- config=SingleModuleStepperConfig.from_state(config),
- img_shape=img_shape,
- area=area,
- sigma_coordinates=sigma_coordinates,
- timestep=timestep,
- # don't need to initialize weights, we're about to load_state
- init_weights=False,
- )
- stepper.load_state(state)
- return stepper
diff --git a/fme/fme/core/test_climate_data.py b/fme/fme/core/test_climate_data.py
index 5733de9..36db0ed 100644
--- a/fme/fme/core/test_climate_data.py
+++ b/fme/fme/core/test_climate_data.py
@@ -3,23 +3,7 @@
import pytest
import torch
-from fme.core.climate_data import (
- ClimateData,
- _height_at_interface,
- _layer_thickness,
- _pressure_at_interface,
- natural_sort,
-)
-
-
-def test__pressure_at_interface():
- ak = torch.tensor([2.0, 0.5, 0.0])
- bk = torch.tensor([0.0, 0.5, 1.0])
- psfc = torch.tensor([[1, 1], [2, 2]])
- pinterface = _pressure_at_interface(ak=ak, bk=bk, surface_pressure=psfc)
- assert pinterface.shape == (2, 2, 3)
- assert pinterface[0, 0, 0] == ak[0]
- assert pinterface[0, 0, -1] == bk[-1] * psfc[0, 0]
+from fme.core.climate_data import ClimateData, _height_at_interface, _layer_thickness
def test__layer_thickness():
@@ -54,58 +38,6 @@ def test__height_at_interface():
)
-@pytest.mark.parametrize(
- "names, sorted_names",
- [
- (
- ["a_1", "b_1", "c_1", "a_2"],
- [
- "a_1",
- "a_2",
- "b_1",
- "c_1",
- ],
- ),
- (
- [
- "a_0",
- "a_1",
- "a_12",
- "a_2",
- ],
- [
- "a_0",
- "a_1",
- "a_2",
- "a_12",
- ],
- ),
- (
- [
- "a_0001",
- "a_0012",
- "a_0002",
- ],
- [
- "a_0001",
- "a_0002",
- "a_0012",
- ],
- ),
- (
- [
- "ab1",
- "aa10",
- "aa2",
- ],
- ["aa2", "aa10", "ab1"],
- ),
- ],
-)
-def test_natural_sort(names, sorted_names):
- assert natural_sort(names) == sorted_names
-
-
@pytest.mark.parametrize("has_water_variable", [True, False])
def test_missing_specific_total_water(has_water_variable):
"""Check shape of specific total water and make sure that it returns None
@@ -176,5 +108,5 @@ def _get_data(missing_water_layer: bool):
2,
)
else:
- with pytest.raises(KeyError):
+ with pytest.raises(ValueError):
_ = climate_data.specific_total_water
diff --git a/fme/fme/core/test_coordinates.py b/fme/fme/core/test_coordinates.py
new file mode 100644
index 0000000..04c3a87
--- /dev/null
+++ b/fme/fme/core/test_coordinates.py
@@ -0,0 +1,115 @@
+import pytest
+import torch
+
+from fme.core.coordinates import (
+ HEALPixCoordinates,
+ HybridSigmaPressureCoordinate,
+ LatLonCoordinates,
+)
+
+
+@pytest.mark.parametrize(
+ "first, second",
+ [
+ (
+ HybridSigmaPressureCoordinate(
+ ak=torch.tensor([1, 2, 3]), bk=torch.tensor([4, 5, 6])
+ ),
+ HybridSigmaPressureCoordinate(
+ ak=torch.tensor([1, 2, 3]), bk=torch.tensor([4, 5, 6])
+ ),
+ ),
+ (
+ LatLonCoordinates(lat=torch.tensor([1, 2, 3]), lon=torch.tensor([4, 5, 6])),
+ LatLonCoordinates(lat=torch.tensor([1, 2, 3]), lon=torch.tensor([4, 5, 6])),
+ ),
+ (
+ HEALPixCoordinates(
+ face=torch.tensor([1, 2, 3]),
+ height=torch.tensor([4, 5, 6]),
+ width=torch.tensor([7, 8, 9]),
+ ),
+ HEALPixCoordinates(
+ face=torch.tensor([1, 2, 3]),
+ height=torch.tensor([4, 5, 6]),
+ width=torch.tensor([7, 8, 9]),
+ ),
+ ),
+ ],
+)
+def test_equality(first, second):
+ assert first == second
+
+
+@pytest.mark.parametrize(
+ "first, second",
+ [
+ (
+ HybridSigmaPressureCoordinate(
+ ak=torch.tensor([1, 2, 3]), bk=torch.tensor([4, 5, 6])
+ ),
+ HybridSigmaPressureCoordinate(
+ ak=torch.tensor([1, 2, 3]), bk=torch.tensor([5, 6, 7])
+ ),
+ ),
+ (
+ LatLonCoordinates(lat=torch.tensor([1, 2, 3]), lon=torch.tensor([4, 5, 6])),
+ LatLonCoordinates(lat=torch.tensor([1, 2, 3]), lon=torch.tensor([5, 6, 7])),
+ ),
+ (
+ HEALPixCoordinates(
+ face=torch.tensor([1, 2, 3]),
+ height=torch.tensor([4, 5, 6]),
+ width=torch.tensor([7, 8, 9]),
+ ),
+ HEALPixCoordinates(
+ face=torch.tensor([1, 2, 3]),
+ height=torch.tensor([4, 5, 6]),
+ width=torch.tensor([8, 9, 10]),
+ ),
+ ),
+ (
+ LatLonCoordinates(lat=torch.tensor([1, 2, 3]), lon=torch.tensor([4, 5, 6])),
+ HEALPixCoordinates(
+ face=torch.tensor([1, 2, 3]),
+ height=torch.tensor([4, 5, 6]),
+ width=torch.tensor([7, 8, 9]),
+ ),
+ ),
+ ],
+)
+def test_inequality(first, second):
+ assert first != second
+
+
+def test_vertical_integral_shape():
+ nlat, nlon, nz = 4, 8, 3
+ water = torch.rand(nlat, nlon, nz)
+ pressure = torch.rand(nlat, nlon)
+ ak, bk = torch.arange(nz + 1), torch.arange(nz + 1)
+ coords = HybridSigmaPressureCoordinate(ak, bk)
+ water_path = coords.vertical_integral(water, pressure)
+ assert water_path.shape == (nlat, nlon)
+
+
+def test_vertical_coordinates_raises_value_error():
+ ak, bk = torch.arange(3), torch.arange(4)
+ with pytest.raises(ValueError):
+ HybridSigmaPressureCoordinate(ak, bk)
+
+
+def test_vertical_coordinates_len():
+ ak, bk = torch.arange(3), torch.arange(3)
+ coords = HybridSigmaPressureCoordinate(ak, bk)
+ assert len(coords) == 3
+
+
+def test_interface_pressure():
+ ak = torch.tensor([2.0, 0.5, 0.0])
+ bk = torch.tensor([0.0, 0.5, 1.0])
+ psfc = torch.tensor([[1, 1], [2, 2]])
+ coords = HybridSigmaPressureCoordinate(ak, bk)
+ pinterface = coords.interface_pressure(psfc)
+ assert pinterface.shape == (2, 2, 3)
+ assert pinterface[0, 0, 0] == ak[0]
+ assert pinterface[0, 0, -1] == bk[-1] * psfc[0, 0]
diff --git a/fme/fme/core/test_corrector.py b/fme/fme/core/test_corrector.py
deleted file mode 100644
index 72a7b46..0000000
--- a/fme/fme/core/test_corrector.py
+++ /dev/null
@@ -1,199 +0,0 @@
-import datetime
-
-import numpy as np
-import pytest
-import torch
-
-from fme.ace.inference.derived_variables import total_water_path_budget_residual
-from fme.core import ClimateData, metrics
-from fme.core.corrector import (
- _force_conserve_dry_air,
- _force_conserve_moisture,
- _force_positive,
- _force_zero_global_mean_moisture_advection,
-)
-from fme.core.data_loading.data_typing import SigmaCoordinates
-from fme.core.loss import get_dry_air_nonconservation
-
-TIMESTEP = datetime.timedelta(hours=6)
-
-
-def test_force_no_global_mean_moisture_advection():
- torch.random.manual_seed(0)
- data = {
- "tendency_of_total_water_path_due_to_advection": torch.rand(size=(3, 2, 5, 5)),
- }
- area_weights = 1.0 + torch.rand(size=(5, 5))
- original_mean = metrics.weighted_mean(
- data["tendency_of_total_water_path_due_to_advection"],
- weights=area_weights,
- dim=[-2, -1],
- )
- assert (original_mean.abs() > 0.0).all()
- fixed_data = _force_zero_global_mean_moisture_advection(
- data,
- area=area_weights,
- )
- new_mean = metrics.weighted_mean(
- fixed_data["tendency_of_total_water_path_due_to_advection"],
- weights=area_weights,
- dim=[-2, -1],
- )
- assert (new_mean.abs() < original_mean.abs()).all()
- np.testing.assert_almost_equal(new_mean.cpu().numpy(), 0.0, decimal=6)
-
-
-def test_force_conserve_dry_air():
- torch.random.manual_seed(0)
- data = {
- "PRESsfc": 10.0 + torch.rand(size=(3, 2, 5, 5)),
- "specific_total_water_0": torch.rand(size=(3, 2, 5, 5)),
- "specific_total_water_1": torch.rand(size=(3, 2, 5, 5)),
- }
- sigma_coordinates = SigmaCoordinates(
- ak=torch.asarray([3.0, 1.0, 0.0]), bk=torch.asarray([0.0, 0.6, 1.0])
- )
- area_weights = 1.0 + torch.rand(size=(5, 5))
- original_nonconservation = get_dry_air_nonconservation(
- data,
- sigma_coordinates=sigma_coordinates,
- area_weights=area_weights,
- )
- assert original_nonconservation > 0.0
- in_data = {k: v.select(dim=1, index=0) for k, v in data.items()}
- out_data = {k: v.select(dim=1, index=1) for k, v in data.items()}
- fixed_out_data = _force_conserve_dry_air(
- in_data,
- out_data,
- sigma_coordinates=sigma_coordinates,
- area=area_weights,
- )
- new_data = {
- k: torch.stack([v, fixed_out_data[k]], dim=1) for k, v in in_data.items()
- }
- new_nonconservation = get_dry_air_nonconservation(
- new_data,
- sigma_coordinates=sigma_coordinates,
- area_weights=area_weights,
- )
- assert new_nonconservation < original_nonconservation
- np.testing.assert_almost_equal(new_nonconservation.cpu().numpy(), 0.0, decimal=6)
-
-
-@pytest.mark.parametrize("fv3_data", [True, False])
-@pytest.mark.parametrize(
- "global_only, terms_to_modify",
- [
- (True, "precipitation"),
- (True, "evaporation"),
- (False, "advection_and_precipitation"),
- (False, "advection_and_evaporation"),
- ],
-)
-def test_force_conserve_moisture(fv3_data: bool, global_only: bool, terms_to_modify):
- torch.random.manual_seed(0)
- if fv3_data:
- data = {
- "PRESsfc": 10.0 + torch.rand(size=(3, 2, 5, 5)),
- "specific_total_water_0": torch.rand(size=(3, 2, 5, 5)),
- "specific_total_water_1": torch.rand(size=(3, 2, 5, 5)),
- "PRATEsfc": torch.rand(size=(3, 2, 5, 5)),
- "LHTFLsfc": torch.rand(size=(3, 2, 5, 5)),
- "tendency_of_total_water_path_due_to_advection": torch.rand(
- size=(3, 2, 5, 5)
- ),
- }
- else:
- data = {
- "PS": 10.0 + torch.rand(size=(3, 2, 5, 5)),
- "specific_total_water_0": torch.rand(size=(3, 2, 5, 5)),
- "specific_total_water_1": torch.rand(size=(3, 2, 5, 5)),
- "surface_precipitation_rate": torch.rand(size=(3, 2, 5, 5)),
- "LHFLX": torch.rand(size=(3, 2, 5, 5)),
- "tendency_of_total_water_path_due_to_advection": torch.rand(
- size=(3, 2, 5, 5)
- ),
- }
- sigma_coordinates = SigmaCoordinates(
- ak=torch.asarray([3.0, 1.0, 0.0]), bk=torch.asarray([0.0, 0.6, 1.0])
- )
- area_weights = 1.0 + torch.rand(size=(5, 5))
- data["tendency_of_total_water_path_due_to_advection"] -= metrics.weighted_mean(
- data["tendency_of_total_water_path_due_to_advection"],
- weights=area_weights,
- dim=[-2, -1],
- )[..., None, None]
- original_budget_residual = total_water_path_budget_residual(
- ClimateData(data),
- sigma_coordinates=sigma_coordinates,
- timestep=TIMESTEP,
- )[
- :, 1
- ] # no meaning for initial value data, want first timestep
- if global_only:
- original_budget_residual = metrics.weighted_mean(
- original_budget_residual, weights=area_weights, dim=[-2, -1]
- )
- original_budget_residual = original_budget_residual.cpu().numpy()
- original_dry_air = (
- ClimateData(data)
- .surface_pressure_due_to_dry_air(sigma_coordinates)
- .cpu()
- .numpy()
- )
- assert np.any(np.abs(original_budget_residual) > 0.0)
- in_data = {k: v.select(dim=1, index=0) for k, v in data.items()}
- out_data = {k: v.select(dim=1, index=1) for k, v in data.items()}
- fixed_out_data = _force_conserve_moisture(
- in_data,
- out_data,
- sigma_coordinates=sigma_coordinates,
- area=area_weights,
- timestep=TIMESTEP,
- terms_to_modify=terms_to_modify,
- )
- new_data = {
- k: torch.stack([v, fixed_out_data[k]], dim=1) for k, v in in_data.items()
- }
- new_budget_residual = total_water_path_budget_residual(
- ClimateData(new_data),
- sigma_coordinates=sigma_coordinates,
- timestep=TIMESTEP,
- )[
- :, 1
- ] # no meaning for initial value data, want first timestep
- new_dry_air = (
- ClimateData(data)
- .surface_pressure_due_to_dry_air(sigma_coordinates)
- .cpu()
- .numpy()
- )
-
- global_budget_residual = (
- metrics.weighted_mean(new_budget_residual, weights=area_weights, dim=[-2, -1])
- .cpu()
- .numpy()
- )
- np.testing.assert_almost_equal(global_budget_residual, 0.0, decimal=6)
-
- if not global_only:
- new_budget_residual = new_budget_residual.cpu().numpy()
- assert np.all(np.abs(new_budget_residual) < np.abs(original_budget_residual))
- np.testing.assert_almost_equal(new_budget_residual, 0.0, decimal=6)
-
- np.testing.assert_almost_equal(new_dry_air, original_dry_air, decimal=6)
-
-
-def test__force_positive():
- data = {
- "foo": torch.tensor([[-1.0, 0.0], [0.0, 1.0], [1.0, 2.0]]),
- "bar": torch.tensor([[-1.0, 0.0], [0.0, -3.0], [1.0, 2.0]]),
- }
- original_min = torch.min(data["foo"])
- assert original_min < 0.0
- fixed_data = _force_positive(data, ["foo"])
- new_min = torch.min(fixed_data["foo"])
- # Ensure the minimum value of 'foo' is now 0
- torch.testing.assert_close(new_min, torch.tensor(0.0))
- # Ensure other variables are not modified
- torch.testing.assert_close(fixed_data["bar"], data["bar"])
diff --git a/fme/fme/core/test_distributed.py b/fme/fme/core/test_distributed.py
index b904dcc..4105c77 100644
--- a/fme/fme/core/test_distributed.py
+++ b/fme/fme/core/test_distributed.py
@@ -1,6 +1,8 @@
import pytest
import torch
+from fme import get_device
+
from .distributed import pad_tensor_at_end, unpad_tensor_at_end
@@ -32,7 +34,7 @@ def test_pad_tensor_at_end(padding, fill_value):
],
)
def test_pad_unpad_rountrip(padding):
- tensor = torch.ones(2, 3, 4)
+ tensor = torch.ones(2, 3, 4, device=get_device())
padded_tensor = pad_tensor_at_end(tensor, padding)
unpadded_tensor = unpad_tensor_at_end(padded_tensor, padding)
assert unpadded_tensor.size() == tensor.size()
diff --git a/fme/fme/core/test_gridded_ops.py b/fme/fme/core/test_gridded_ops.py
new file mode 100644
index 0000000..910fcf6
--- /dev/null
+++ b/fme/fme/core/test_gridded_ops.py
@@ -0,0 +1,39 @@
+from typing import Any, Dict, Type
+
+import pytest
+import torch
+
+from fme.core.gridded_ops import GriddedOperations, HEALPixOperations, LatLonOperations
+
+
+@pytest.mark.parametrize(
+ "state, expected_class",
+ [
+ (
+ {
+ "type": "LatLonOperations",
+ "state": {"area_weights": torch.tensor([1.0, 2.0])},
+ },
+ LatLonOperations,
+ ),
+ (
+ {
+ "type": "HEALPixOperations",
+ "state": {},
+ },
+ HEALPixOperations,
+ ),
+ ],
+)
+def test_gridded_operations_from_state(
+ state: Dict[str, Any],
+ expected_class: Type[GriddedOperations],
+):
+ ops = GriddedOperations.from_state(state)
+ assert isinstance(ops, expected_class)
+
+ recovered_state = ops.to_state()
+ assert recovered_state == state
+
+ with pytest.raises(RuntimeError):
+ expected_class.from_state(state["state"])
diff --git a/fme/fme/core/test_histogram.py b/fme/fme/core/test_histogram.py
index e81c500..dfb1339 100644
--- a/fme/fme/core/test_histogram.py
+++ b/fme/fme/core/test_histogram.py
@@ -79,20 +79,28 @@ def test_dynamic_histogram_random_values(n_times: int, time_bin_len: int):
def test_dynamic_histogram_extends_as_expected():
histogram = DynamicHistogram(n_times=1, n_bins=200)
histogram.add(torch.as_tensor([[-1.0, 0.0, 1.0]]))
- np.testing.assert_approx_equal(histogram.bin_edges[0], -1.0, significant=6)
- np.testing.assert_approx_equal(histogram.bin_edges[-1], 1.0, significant=6)
+ bin_edges = histogram.bin_edges
+ assert bin_edges is not None
+ np.testing.assert_approx_equal(bin_edges[0], -1.0, significant=6)
+ np.testing.assert_approx_equal(bin_edges[-1], 1.0, significant=6)
histogram.add(torch.as_tensor([[-2.0]]))
+ bin_edges = histogram.bin_edges
+ assert bin_edges is not None
# double in size to the left, length becomes 4, from -3 to 1.0
- np.testing.assert_approx_equal(histogram.bin_edges[0], -3.0, significant=6)
- np.testing.assert_approx_equal(histogram.bin_edges[-1], 1.0, significant=6)
+ np.testing.assert_approx_equal(bin_edges[0], -3.0, significant=6)
+ np.testing.assert_approx_equal(bin_edges[-1], 1.0, significant=6)
histogram.add(torch.as_tensor([[2.0]]))
+ bin_edges = histogram.bin_edges
+ assert bin_edges is not None
# double in size to the right, length becomes 8, from -3 to 5.0
- np.testing.assert_approx_equal(histogram.bin_edges[0], -3.0, significant=6)
- np.testing.assert_approx_equal(histogram.bin_edges[-1], 5.0, significant=6)
+ np.testing.assert_approx_equal(bin_edges[0], -3.0, significant=6)
+ np.testing.assert_approx_equal(bin_edges[-1], 5.0, significant=6)
histogram.add(torch.as_tensor([[27.0]]))
+ bin_edges = histogram.bin_edges
+ assert bin_edges is not None
# double in size twice to the right, length becomes 32, from -3 to 29.0
- np.testing.assert_approx_equal(histogram.bin_edges[0], -3.0, significant=6)
- np.testing.assert_approx_equal(histogram.bin_edges[-1], 29.0, significant=6)
+ np.testing.assert_approx_equal(bin_edges[0], -3.0, significant=6)
+ np.testing.assert_approx_equal(bin_edges[-1], 29.0, significant=6)
def test_histogram_handles_uniform_field():
@@ -134,8 +142,8 @@ def test_compared_dynamic_histograms(shape, percentiles):
for data_type in ["target", "prediction"]:
assert isinstance(wandb_result[f"{var_name}"], matplotlib.figure.Figure)
for p in percentiles:
- assert (
- wandb_result[f"{data_type}/{p}th-percentile/{var_name}"].shape == ()
+ assert isinstance(
+ wandb_result[f"{data_type}/{p}th-percentile/{var_name}"], float
)
ds = histogram.get_dataset()
diff --git a/fme/fme/core/test_loss.py b/fme/fme/core/test_loss.py
index dbb927c..52952d0 100644
--- a/fme/fme/core/test_loss.py
+++ b/fme/fme/core/test_loss.py
@@ -3,6 +3,7 @@
from fme.core import metrics
from fme.core.device import get_device
+from fme.core.gridded_ops import LatLonOperations
from fme.core.loss import (
AreaWeightedMSELoss,
GlobalMeanLoss,
@@ -19,7 +20,7 @@
def test_loss_builds_and_runs(global_mean_type):
config = LossConfig(global_mean_type=global_mean_type)
area = torch.randn(10, 10, device=get_device())
- loss = config.build(area, reduction="mean")
+ loss = config.build(LatLonOperations(area).area_weighted_mean, reduction="mean")
x = torch.randn(10, 10, 10, 10, 10, device=get_device())
y = torch.randn(10, 10, 10, 10, 10, device=get_device())
result = loss(x, y)
@@ -28,36 +29,47 @@ def test_loss_builds_and_runs(global_mean_type):
def test_loss_of_zeros_is_variance():
+ torch.manual_seed(0)
config = LossConfig(global_mean_type=None)
area = torch.randn(10, 10, device=get_device())
- loss = config.build(area, reduction="mean")
+ loss = config.build(LatLonOperations(area).area_weighted_mean, reduction="mean")
x = torch.zeros(10, 10, 10, 10, 10, device=get_device())
y = torch.randn(10, 10, 10, 10, 10, device=get_device())
result = loss(x, y)
assert result.shape == ()
assert isinstance(result, torch.Tensor)
- torch.testing.assert_close(result, y.var())
+ tol = {"rtol": 1e-4, "atol": 1e-4} if str(get_device()).startswith("cuda") else {}
+ torch.testing.assert_close(result, y.var(), **tol)
@pytest.mark.parametrize("global_mean_weight", [0.0, 1.0, 5.0])
def test_loss_of_zeros_is_one_plus_global_mean_weight(global_mean_weight: float):
+ torch.manual_seed(0)
config = LossConfig(
global_mean_type="LpLoss", global_mean_weight=global_mean_weight
)
area = torch.randn(10, 10, device=get_device())
- loss = config.build(area, reduction="mean")
+ loss = config.build(LatLonOperations(area).area_weighted_mean, reduction="mean")
x = torch.zeros(10, 10, 10, 10, 10, device=get_device())
y = torch.randn(10, 10, 10, 10, 10, device=get_device())
result = loss(x, y)
assert result.shape == ()
assert isinstance(result, torch.Tensor)
expected = torch.tensor(1.0 + global_mean_weight)
- torch.testing.assert_close(result.cpu(), expected, atol=0.01, rtol=0)
+ tol = (
+ {"atol": 0.015, "rtol": 0.01}
+ if str(get_device()).startswith("cuda")
+ else {"atol": 0.01, "rtol": 0.0}
+ )
+ torch.testing.assert_close(result.cpu(), expected, **tol)
def test_global_mean_loss():
+ torch.manual_seed(0)
area = torch.randn(10, 10, device=get_device())
- loss = GlobalMeanLoss(area=area, loss=torch.nn.MSELoss())
+ loss = GlobalMeanLoss(
+ LatLonOperations(area).area_weighted_mean, loss=torch.nn.MSELoss()
+ )
x = torch.zeros(10, 10, 10, 10, 10, device=get_device())
y = torch.randn(10, 10, 10, 10, 10, device=get_device())
result = loss(x, y)
@@ -77,7 +89,7 @@ def test_area_weighted_mse():
x = torch.rand(10, 10).to(get_device())
target = torch.rand(10, 10).to(get_device())
area = torch.rand(10, 10).to(get_device())
- area_weighted_mse = AreaWeightedMSELoss(area)
+ area_weighted_mse = AreaWeightedMSELoss(LatLonOperations(area).area_weighted_mean)
result = area_weighted_mse(x, target)
expected = metrics.weighted_mean(
torch.nn.MSELoss(reduction="none")(x, target), weights=area, dim=(-2, -1)
@@ -160,9 +172,10 @@ def test_WeightedMappingLossConfig_no_weights():
out_names = [f"var_{i}" for i in range(n_channels)]
channel_dim = -3
area = torch.tensor([]) # area not used by this config
+ area_weighted_mean = LatLonOperations(area).area_weighted_mean
mapping_loss_config = WeightedMappingLossConfig()
- loss = loss_config.build(area, reduction="mean")
- mapping_loss = mapping_loss_config.build(area, out_names, channel_dim)
+ loss = loss_config.build(area_weighted_mean, reduction="mean")
+ mapping_loss = mapping_loss_config.build(area_weighted_mean, out_names, channel_dim)
packer = Packer(out_names)
x_mapping = {name: torch.randn(4, 5, 5).to(get_device()) for name in out_names}
@@ -180,7 +193,9 @@ def test_WeightedMappingLossConfig_weights():
mapping_loss_config = WeightedMappingLossConfig(
type="MSE", weights={"var_0": 4.0, "var_1": 1.0}
)
- mapping_loss = mapping_loss_config.build(area, out_names, channel_dim)
+ mapping_loss = mapping_loss_config.build(
+ LatLonOperations(area).area_weighted_mean, out_names, channel_dim
+ )
x0 = torch.ones(4, 5, 5).to(get_device())
x1 = 2.0 * x0
diff --git a/fme/fme/core/test_masking.py b/fme/fme/core/test_masking.py
new file mode 100644
index 0000000..5a32ce8
--- /dev/null
+++ b/fme/fme/core/test_masking.py
@@ -0,0 +1,120 @@
+import pytest
+import torch
+
+import fme
+from fme.core.masking import MaskingConfig
+from fme.core.stacker import Stacker
+
+
+def test_masking_config():
+ config = MaskingConfig(
+ mask_name="a",
+ surface_mask_name="b",
+ mask_value=1,
+ fill_value=0.0,
+ )
+ mask = config.build()
+ assert mask.mask_name == "a"
+ assert mask.mask_value == 1
+ assert mask.fill_value == 0.0
+
+ with pytest.raises(ValueError) as err:
+ _ = MaskingConfig(
+ mask_name="a",
+ surface_mask_name="b",
+ mask_value=3,
+ fill_value=0.0,
+ )
+ assert "mask_value must be either 0 or 1" in str(err.value)
+
+
+_SIZE = (4, 4)
+
+_MASK_DATA = {
+ "surface_mask": torch.ones(_SIZE, device=fme.get_device()),
+ "mask_0": torch.ones(_SIZE, device=fme.get_device()),
+ "mask_1": torch.ones(_SIZE, device=fme.get_device()),
+}
+_MASK_DATA["surface_mask"][1, 1] = 0
+_MASK_DATA["mask_0"][0, :] = 0
+_MASK_DATA["mask_1"][1, :] = 0
+
+
+_DATA = {
+ "PRESsfc": 10.0 + torch.rand(size=_SIZE, device=fme.get_device()),
+ "specific_total_water_0": torch.rand(size=_SIZE, device=fme.get_device()),
+ "specific_total_water_1": torch.rand(size=_SIZE, device=fme.get_device()),
+}
+
+
+def test_masking():
+ config = MaskingConfig(
+ mask_name="mask",
+ surface_mask_name="surface_mask",
+ mask_value=0,
+ fill_value=0.0,
+ )
+ mask = config.build()
+ stacker = Stacker(
+ {
+ "surface_pressure": ["PRESsfc", "PS"],
+ "specific_total_water": ["specific_total_water_"],
+ }
+ )
+ output = mask(stacker, _DATA, _MASK_DATA)
+ assert output["PRESsfc"][1, 1] == 0.0
+ assert output["PRESsfc"][0, 1] != 0.0
+ assert torch.all(output["specific_total_water_0"][0, :] == 0.0)
+ assert torch.all(output["specific_total_water_1"][1, :] == 0.0)
+ assert torch.all(output["specific_total_water_1"][0, :] != 0.0)
+
+
+def test_masking_no_3d_masking():
+ config = MaskingConfig(
+ mask_name="surface_mask",
+ mask_value=0,
+ fill_value=0.0,
+ )
+ mask = config.build()
+ stacker = Stacker({"surface_pressure": ["PRESsfc", "PS"]})
+ output = mask(stacker, _DATA, _MASK_DATA)
+ assert output["PRESsfc"][1, 1] == 0.0
+ assert output["PRESsfc"][0, 1] != 0.0
+ assert torch.all(output["specific_total_water_0"][0, :] != 0.0)
+ assert torch.all(output["specific_total_water_1"][1, :] != 0.0)
+ assert torch.all(output["specific_total_water_1"][0, :] != 0.0)
+
+
+def test_masking_no_surface_masking():
+ config = MaskingConfig(
+ mask_name="mask",
+ mask_value=0,
+ fill_value=0.0,
+ )
+ mask = config.build()
+ stacker = Stacker({"specific_total_water": ["specific_total_water_"]})
+ output = mask(stacker, _DATA, _MASK_DATA)
+ assert output["PRESsfc"][1, 1] != 0.0
+ assert output["PRESsfc"][0, 1] != 0.0
+ assert torch.all(output["specific_total_water_0"][0, :] == 0.0)
+ assert torch.all(output["specific_total_water_1"][1, :] == 0.0)
+ assert torch.all(output["specific_total_water_1"][0, :] != 0.0)
+
+
+def test_masking_missing_2d_mask():
+ config = MaskingConfig(
+ mask_name="mask",
+ mask_value=0,
+ fill_value=0.0,
+ )
+ mask = config.build()
+ stacker = Stacker(
+ {
+ "surface_pressure": ["PRESsfc", "PS"],
+ "specific_total_water": ["specific_total_water_"],
+ }
+ )
+ with pytest.raises(RuntimeError) as err:
+ _ = mask(stacker, _DATA, _MASK_DATA)
+ assert "surface_mask_name is None" in str(err.value)
+ assert "surface_pressure" in str(err.value)
diff --git a/fme/fme/core/test_metrics.py b/fme/fme/core/test_metrics.py
index fdf29dc..71b9077 100644
--- a/fme/fme/core/test_metrics.py
+++ b/fme/fme/core/test_metrics.py
@@ -4,12 +4,12 @@
import torch_harmonics
import fme
+from fme.core.coordinates import HybridSigmaPressureCoordinate
from fme.core.metrics import (
net_surface_energy_flux,
quantile,
spherical_power_spectrum,
surface_pressure_due_to_dry_air,
- vertical_integral,
)
@@ -255,22 +255,13 @@ def test_gradient_magnitude_percent_diff():
torch.testing.assert_close(percent_diff, -100 * torch.ones((5, 2)))
-def test_vertical_integral_shape():
- nlat, nlon, nz = 4, 8, 3
- water = torch.rand(nlat, nlon, nz)
- pressure = torch.rand(nlat, nlon)
- ak, bk = torch.arange(nz + 1), torch.arange(nz + 1)
- water_path = vertical_integral(water, pressure, ak, bk)
- assert water_path.shape == (nlat, nlon)
-
-
def test_dry_air_shapes():
nlat, nlon, nz = 4, 8, 3
water = torch.rand(nlat, nlon, nz)
pressure = torch.rand(nlat, nlon)
ak, bk = torch.arange(nz + 1), torch.arange(nz + 1)
-
- dry_air = surface_pressure_due_to_dry_air(water, pressure, ak, bk)
+ coords = HybridSigmaPressureCoordinate(ak, bk)
+ dry_air = surface_pressure_due_to_dry_air(water, pressure, coords)
assert dry_air.shape == (nlat, nlon)
@@ -286,8 +277,8 @@ def test_single_level_dry_air_no_water():
water = torch.zeros(nlat, nlon, nz)
pressure = torch.rand(nlat, nlon)
ak, bk = single_level_ak_bk()
-
- dry_air = surface_pressure_due_to_dry_air(water, pressure, ak, bk)
+ coords = HybridSigmaPressureCoordinate(ak, bk)
+ dry_air = surface_pressure_due_to_dry_air(water, pressure, coords)
np.testing.assert_allclose(dry_air.cpu().numpy(), pressure.cpu().numpy())
@@ -297,8 +288,8 @@ def test_single_level_dry_air_all_water():
water = torch.ones(nlat, nlon, nz)
pressure = torch.rand(nlat, nlon)
ak, bk = single_level_ak_bk()
-
- dry_air = surface_pressure_due_to_dry_air(water, pressure, ak, bk)
+ coords = HybridSigmaPressureCoordinate(ak, bk)
+ dry_air = surface_pressure_due_to_dry_air(water, pressure, coords)
np.testing.assert_almost_equal(dry_air.cpu().numpy(), 0.0, decimal=6)
@@ -309,8 +300,8 @@ def test_single_level_dry_air_some_water():
pressure = torch.rand(nlat, nlon)
ak, bk = single_level_ak_bk()
target_dry_air = pressure * (1.0 - water[:, :, 0])
-
- dry_air = surface_pressure_due_to_dry_air(water, pressure, ak, bk)
+ coords = HybridSigmaPressureCoordinate(ak, bk)
+ dry_air = surface_pressure_due_to_dry_air(water, pressure, coords)
np.testing.assert_allclose(
dry_air.cpu().numpy(), target_dry_air.cpu().numpy(), rtol=1e-5
)
diff --git a/fme/fme/core/test_normalizer.py b/fme/fme/core/test_normalizer.py
index 5aaeca4..06d2edb 100644
--- a/fme/fme/core/test_normalizer.py
+++ b/fme/fme/core/test_normalizer.py
@@ -1,8 +1,8 @@
import pytest
import torch
+from fme.core.device import move_tensordict_to_device
from fme.core.normalizer import NormalizationConfig, StandardNormalizer
-from fme.core.typing_ import TensorDict
def test_normalize_depends_on_mean():
@@ -48,35 +48,35 @@ def test_denormalize_depends_on_std():
def test_normalize_and_denormalize_random_tensor():
torch.manual_seed(0)
# randomly set means and stds
- means = {"a": torch.randn(1), "b": torch.randn(1)}
- stds = {"a": torch.randn(1), "b": torch.randn(1)}
+ means = move_tensordict_to_device({"a": torch.randn(1), "b": torch.randn(1)})
+ stds = move_tensordict_to_device({"a": torch.randn(1), "b": torch.randn(1)})
normalizer = StandardNormalizer(means=means, stds=stds)
- tensors = {"a": torch.randn(10), "b": torch.randn(10)}
+ tensors = move_tensordict_to_device({"a": torch.randn(10), "b": torch.randn(10)})
denormalized = normalizer.denormalize(normalizer.normalize(tensors))
assert torch.allclose(denormalized["a"], tensors["a"])
assert torch.allclose(denormalized["b"], tensors["b"])
-def test_normalization_config_exclude_names():
- torch.manual_seed(0)
+def test_missing_normalization_build_raises_error():
normalization = NormalizationConfig(
means={"a": 1.0, "b": 2.0},
stds={"a": 1.0, "b": 1.0},
- exclude_names=["c"],
)
- normalizer = normalization.build(["a", "b", "c"])
- tensors = {"a": torch.randn(10), "b": torch.randn(10), "c": torch.randn(10)}
- normalized: TensorDict = normalizer.normalize(tensors)
- denormalized = normalizer.denormalize(normalized)
- assert torch.all(normalized["c"] == tensors["c"])
- assert torch.all(denormalized["c"] == tensors["c"])
+ all_names = ["a", "b", "c"]
+ with pytest.raises(KeyError):
+ normalization.build(all_names)
-def test_missing_normalization_raises_error():
+def test_tensors_with_missing_normalization_stats_get_filtered():
normalization = NormalizationConfig(
means={"a": 1.0, "b": 2.0},
stds={"a": 1.0, "b": 1.0},
- )
- all_names = ["a", "b", "c"]
- with pytest.raises(KeyError):
- normalization.build(all_names)
+ ).build(["a", "b"])
+ sample_input = {"a": torch.zeros(1), "b": torch.zeros(1), "c": torch.zeros(1)}
+ sample_input = move_tensordict_to_device(sample_input)
+
+ normalized = normalization.normalize(sample_input)
+ assert "c" not in normalized
+
+ denormalized = normalization.denormalize(sample_input)
+ assert "c" not in denormalized
diff --git a/fme/fme/core/test_ocean.py b/fme/fme/core/test_ocean.py
index 0fbd9da..26de5ec 100644
--- a/fme/fme/core/test_ocean.py
+++ b/fme/fme/core/test_ocean.py
@@ -18,7 +18,7 @@ def test_ocean_prescribed():
target_data = {"sst": torch.tensor([22.0, 25.0]), "of": torch.tensor([0.2, 0.8])}
input_data = {"sst": torch.tensor([20.0, 21.0]), "foo": torch.tensor([1, 2])}
gen_data = {"sst": torch.tensor([23.0, 26.0]), "foo": torch.tensor([2, 3])}
- output_data = ocean(target_data, input_data, gen_data)
+ output_data = ocean(input_data, gen_data, target_data)
expected_output = {"sst": torch.tensor([23.0, 25.0]), "foo": torch.tensor([2, 3])}
assert set(output_data) == set(expected_output)
for name in output_data:
@@ -52,7 +52,7 @@ def test_ocean_slab():
}
input_data = {"sst": torch.tensor([20.0])}
gen_data = {**fluxes, "sst": torch.tensor([25.0])}
- output_data = ocean(target_data, input_data, gen_data)
+ output_data = ocean(input_data, gen_data, target_data)
expected_sst_tendency = mixed_layer_temperature_tendency(
expected_net_surface_energy_flux, target_data["qf"], target_data["mld"]
)
diff --git a/fme/fme/core/test_parameter_init.py b/fme/fme/core/test_parameter_init.py
index 084b4b4..46857dc 100644
--- a/fme/fme/core/test_parameter_init.py
+++ b/fme/fme/core/test_parameter_init.py
@@ -9,11 +9,11 @@
import pytest
import torch
+from fme.ace.stepper import SingleModuleStepper, SingleModuleStepperConfig
from fme.core import parameter_init
-from fme.core.data_loading.data_typing import SigmaCoordinates
+from fme.core.coordinates import HybridSigmaPressureCoordinate
from fme.core.device import get_device
-from fme.core.normalizer import FromStateNormalizer
-from fme.core.stepper import SingleModuleStepper, SingleModuleStepperConfig
+from fme.core.gridded_ops import LatLonOperations
from fme.core.typing_ import TensorMapping
from fme.core.wildcard import wildcard_match
@@ -32,22 +32,20 @@ def test_builder_with_weights_loads_same_state(tmpdir):
"builder": sfno_config_data,
"in_names": ["x"],
"out_names": ["x"],
- "normalization": FromStateNormalizer(
- state={
- "means": {"x": np.random.randn(1)},
- "stds": {"x": np.random.randn(1)},
- }
- ),
+ "normalization": {
+ "means": {"x": np.random.randn(1).item()},
+ "stds": {"x": np.random.randn(1).item()},
+ },
}
area = torch.ones((1, 16, 32)).to(get_device())
- sigma_coordinates = SigmaCoordinates(ak=torch.arange(7), bk=torch.arange(7)).to(
- get_device()
- )
+ vertical_coordinate = HybridSigmaPressureCoordinate(
+ ak=torch.arange(7), bk=torch.arange(7)
+ ).to(get_device())
stepper_config = SingleModuleStepperConfig.from_state(stepper_config_data)
stepper = stepper_config.get_stepper(
img_shape=(16, 32),
- area=area,
- sigma_coordinates=sigma_coordinates,
+ gridded_operations=LatLonOperations(area),
+ vertical_coordinate=vertical_coordinate,
timestep=TIMESTEP,
)
torch.save(
@@ -64,19 +62,17 @@ def test_builder_with_weights_loads_same_state(tmpdir):
"parameter_init": parameter_init_config,
"in_names": ["x"],
"out_names": ["x"],
- "normalization": FromStateNormalizer(
- state={
- "means": {"x": np.random.randn(1)},
- "stds": {"x": np.random.randn(1)},
- }
- ),
+ "normalization": {
+ "means": {"x": np.random.randn(1).item()},
+ "stds": {"x": np.random.randn(1).item()},
+ },
}
with_builder_stepper = SingleModuleStepperConfig.from_state(
with_builder_stepper_config_data
).get_stepper(
img_shape=(16, 32),
- area=area,
- sigma_coordinates=sigma_coordinates,
+ gridded_operations=LatLonOperations(area),
+ vertical_coordinate=vertical_coordinate,
timestep=TIMESTEP,
)
assert_same_state(
@@ -148,7 +144,7 @@ def test_builder_with_weights_sfno_init(
"""
Integration test for the BuilderWithWeights stepper with a SFNO.
"""
- with_builder_stepper_config_data, area, sigma_coordinates, stepper = get_config(
+ with_builder_stepper_config_data, area, vertical_coordinate, stepper = get_config(
loaded_shape, extra_built_layer, tmpdir
)
if expect_exception:
@@ -157,8 +153,8 @@ def test_builder_with_weights_sfno_init(
with_builder_stepper_config_data
).get_stepper(
img_shape=built_shape,
- area=area,
- sigma_coordinates=sigma_coordinates,
+ gridded_operations=LatLonOperations(area),
+ vertical_coordinate=vertical_coordinate,
timestep=TIMESTEP,
)
else:
@@ -166,8 +162,8 @@ def test_builder_with_weights_sfno_init(
with_builder_stepper_config_data
).get_stepper(
img_shape=built_shape,
- area=area,
- sigma_coordinates=sigma_coordinates,
+ gridded_operations=LatLonOperations(area),
+ vertical_coordinate=vertical_coordinate,
timestep=TIMESTEP,
)
if extra_built_layer:
@@ -203,22 +199,20 @@ def get_config(
"builder": sfno_config_data,
"in_names": ["x"],
"out_names": ["x"],
- "normalization": FromStateNormalizer(
- state={
- "means": {"x": np.random.randn(1)},
- "stds": {"x": np.random.randn(1)},
- }
- ),
+ "normalization": {
+ "means": {"x": np.random.randn(1).item()},
+ "stds": {"x": np.random.randn(1).item()},
+ },
}
area = torch.ones((1, 16, 32)).to(get_device())
- sigma_coordinates = SigmaCoordinates(ak=torch.arange(7), bk=torch.arange(7)).to(
- get_device()
- )
+ vertical_coordinate = HybridSigmaPressureCoordinate(
+ ak=torch.arange(7), bk=torch.arange(7)
+ ).to(get_device())
stepper_config = SingleModuleStepperConfig.from_state(stepper_config_data)
stepper = stepper_config.get_stepper(
img_shape=loaded_shape,
- area=area,
- sigma_coordinates=sigma_coordinates,
+ gridded_operations=LatLonOperations(area),
+ vertical_coordinate=vertical_coordinate,
timestep=TIMESTEP,
)
built_sfno_config_data = copy.deepcopy(sfno_config_data)
@@ -238,35 +232,31 @@ def get_config(
"parameter_init": parameter_init_config,
"in_names": ["x"],
"out_names": ["x"],
- "normalization": FromStateNormalizer(
- state={
- "means": {"x": np.random.randn(1)},
- "stds": {"x": np.random.randn(1)},
- }
- ),
+ "normalization": {
+ "means": {"x": np.random.randn(1).item()},
+ "stds": {"x": np.random.randn(1).item()},
+ },
}
- return with_builder_stepper_config_data, area, sigma_coordinates, stepper
+ return with_builder_stepper_config_data, area, vertical_coordinate, stepper
def test_with_weights_saved_stepper_does_not_need_untuned_weights(tmpdir):
img_shape = (16, 32)
- with_builder_stepper_config_data, area, sigma_coordinates, stepper = get_config(
+ with_builder_stepper_config_data, area, vertical_coordinate, stepper = get_config(
loaded_shape=img_shape, extra_built_layer=False, tmpdir=tmpdir
)
with_builder_stepper = SingleModuleStepperConfig.from_state(
with_builder_stepper_config_data
).get_stepper(
img_shape=img_shape,
- area=area,
- sigma_coordinates=sigma_coordinates,
+ gridded_operations=LatLonOperations(area),
+ vertical_coordinate=vertical_coordinate,
timestep=TIMESTEP,
)
stepper_state = with_builder_stepper.get_state()
# should be able to initialize stepper from its state without the untuned weights
(tmpdir / "weights.ckpt").remove()
- stepper = SingleModuleStepper.from_state(
- stepper_state, area=area, sigma_coordinates=sigma_coordinates
- )
+ stepper = SingleModuleStepper.from_state(stepper_state)
assert isinstance(stepper, SingleModuleStepper)
diff --git a/fme/fme/core/test_registry.py b/fme/fme/core/test_registry.py
deleted file mode 100644
index 006febf..0000000
--- a/fme/fme/core/test_registry.py
+++ /dev/null
@@ -1,52 +0,0 @@
-import dataclasses
-from typing import Iterable, Tuple
-
-import pytest
-import torch
-
-from fme.core import registry
-
-
-class MockModule(torch.nn.Module):
- def __init__(self, param_shapes: Iterable[Tuple[int, ...]]):
- super().__init__()
- for i, shape in enumerate(param_shapes):
- setattr(self, f"param{i}", torch.nn.Parameter(torch.randn(shape)))
-
-
-@dataclasses.dataclass
-class MockModuleBuilder:
- param_shapes: Iterable[Tuple[int, ...]]
-
- def build(self, n_in_channels, n_out_channels, img_shape):
- return MockModule(self.param_shapes)
-
- @classmethod
- def from_state(self, state):
- return MockModuleBuilder(state["param_shapes"])
-
- def get_state(self):
- return {
- "param_shapes": self.param_shapes,
- }
-
-
-def test_register():
- """Make sure that the registry is working as expected."""
- original_registry = registry.NET_REGISTRY
- try:
- registry.NET_REGISTRY = {}
- with pytest.raises(ValueError):
- selector = registry.ModuleSelector(
- type="mock",
- config={"param_shapes": [(1, 2, 3)]},
- )
- registry.register("mock")(MockModuleBuilder)
- selector = registry.ModuleSelector(
- type="mock",
- config={"param_shapes": [(1, 2, 3)]},
- )
- module = selector.build(1, 1, (16, 32))
- assert isinstance(module, MockModule)
- finally:
- registry.NET_REGISTRY = original_registry
diff --git a/fme/fme/core/test_stacker.py b/fme/fme/core/test_stacker.py
new file mode 100644
index 0000000..25cf8be
--- /dev/null
+++ b/fme/fme/core/test_stacker.py
@@ -0,0 +1,139 @@
+import pytest
+import torch
+
+from .stacker import Stacker, natural_sort, unstack
+
+
+@pytest.mark.parametrize(
+ "names, sorted_names",
+ [
+ (
+ ["a_1", "b_1", "c_1", "a_2"],
+ [
+ "a_1",
+ "a_2",
+ "b_1",
+ "c_1",
+ ],
+ ),
+ (
+ [
+ "a_0",
+ "a_1",
+ "a_12",
+ "a_2",
+ ],
+ [
+ "a_0",
+ "a_1",
+ "a_2",
+ "a_12",
+ ],
+ ),
+ (
+ [
+ "a_0001",
+ "a_0012",
+ "a_0002",
+ ],
+ [
+ "a_0001",
+ "a_0002",
+ "a_0012",
+ ],
+ ),
+ (
+ [
+ "ab1",
+ "aa10",
+ "aa2",
+ ],
+ ["aa2", "aa10", "ab1"],
+ ),
+ ],
+)
+def test_natural_sort(names, sorted_names):
+ assert natural_sort(names) == sorted_names
+
+
+_PREFIX_MAPS = [{"a": ["aa_", "ab_"], "b": ["b_", "bb_"], "c": ["ca", "cb"]}]
+
+_DATA_NAMES = [
+ ["aa_0", "aa_1", "b_00", "b_01", "b", "ca"],
+ ["ab_0", "ab_1", "bb_1", "bb_2", "cb"],
+ ["ab_0", "b_", "cc"],
+]
+
+
+@pytest.mark.parametrize(
+ "prefix_map, data_names, expected_level_names",
+ [
+ (
+ _PREFIX_MAPS[0],
+ _DATA_NAMES[0],
+ {"a": ["aa_0", "aa_1"], "b": ["b_00", "b_01"], "c": ["ca"]},
+ ),
+ (
+ _PREFIX_MAPS[0],
+ _DATA_NAMES[1],
+ {"a": ["ab_0", "ab_1"], "b": ValueError, "c": ["cb"]},
+ ),
+ (
+ _PREFIX_MAPS[0],
+ _DATA_NAMES[2],
+ {"a": ["ab_0"], "b": ["b_"], "c": KeyError},
+ ),
+ ],
+)
+def test_get_all_level_names(prefix_map, data_names, expected_level_names):
+ data = {name: torch.zeros(2, 2) for name in data_names}
+ stacker = Stacker(prefix_map)
+ for standard_name in expected_level_names.keys():
+ if isinstance(expected_level_names[standard_name], type):
+ with pytest.raises(expected_level_names[standard_name]):
+ stacker.get_all_level_names(standard_name, data)
+ else:
+ assert (
+ stacker.get_all_level_names(standard_name, data)
+ == expected_level_names[standard_name]
+ )
+
+
+@pytest.mark.parametrize(
+ "prefix_map, data_names, expected_level_names",
+ [
+ (
+ _PREFIX_MAPS[0],
+ _DATA_NAMES[0],
+ {"a": ["aa_0", "aa_1"], "b": ["b_00", "b_01"], "c": ["ca"]},
+ ),
+ (
+ _PREFIX_MAPS[0],
+ _DATA_NAMES[1],
+ {"a": ["ab_0", "ab_1"], "c": ["cb"]},
+ ),
+ (
+ _PREFIX_MAPS[0],
+ _DATA_NAMES[2],
+ {"a": ["ab_0"], "b": ["b_"]},
+ ),
+ ],
+)
+def test_stack_unstack(prefix_map, data_names, expected_level_names):
+ torch.manual_seed(0)
+ data = {name: torch.rand(2, 2) for name in data_names}
+ stacker = Stacker(prefix_map)
+ for standard_name in expected_level_names.keys():
+ level_names = expected_level_names[standard_name]
+ if len(level_names) == 1:
+ expected_stacked = data[level_names[0]].unsqueeze(-1)
+ else:
+ expected_stacked = torch.stack([data[name] for name in level_names], dim=-1)
+ stacked = stacker(standard_name, data)
+ assert torch.allclose(stacked, expected_stacked)
+ unstacked = unstack(
+ stacked, names=stacker.get_all_level_names(standard_name, data)
+ )
+ for name in level_names:
+ assert name in unstacked
+ assert torch.allclose(unstacked[name], data[name])
diff --git a/fme/fme/core/test_timing.py b/fme/fme/core/test_timing.py
new file mode 100644
index 0000000..b5e78cf
--- /dev/null
+++ b/fme/fme/core/test_timing.py
@@ -0,0 +1,199 @@
+import logging
+import time
+
+import numpy as np
+import pytest
+
+from fme.core.timing import CumulativeTimer, GlobalTimer
+
+
+def test_CumulativeTimer():
+ category = "foo"
+ cumulative_timer = CumulativeTimer(category)
+
+ cumulative_timer.start()
+ time.sleep(0.01)
+ cumulative_timer.stop()
+
+ time.sleep(0.01)
+
+ cumulative_timer.start()
+ time.sleep(0.01)
+ cumulative_timer.stop()
+
+ assert cumulative_timer.duration == pytest.approx(0.02, abs=0.005)
+
+
+def test_CumulativeTimer_start_error():
+ category = "foo"
+ cumulative_timer = CumulativeTimer(category)
+ cumulative_timer.start()
+ with pytest.raises(RuntimeError, match=f"timer {category!r} is already running"):
+ cumulative_timer.start()
+
+
+def test_CumulativeTimer_stop_error():
+ category = "foo"
+ cumulative_timer = CumulativeTimer(category)
+ with pytest.raises(
+ RuntimeError, match=f"must call start for timer {category!r} before stop"
+ ):
+ cumulative_timer.stop()
+
+
+def test_CumulativeTimer_duration_error():
+ category = "foo"
+ cumulative_timer = CumulativeTimer(category)
+ cumulative_timer.start()
+ with pytest.raises(RuntimeError, match=f"timer {category!r} is still running"):
+ cumulative_timer.duration
+
+
+def exercise_active_timer():
+ timer = GlobalTimer.get_instance()
+
+ timer.start("foo")
+ time.sleep(0.01)
+ timer.stop()
+
+ timer = GlobalTimer.get_instance()
+
+ timer.start("bar")
+ time.sleep(0.02)
+ timer.stop()
+
+ assert timer.get_duration("foo") == pytest.approx(0.01, abs=0.005)
+ assert timer.get_duration("bar") == pytest.approx(0.02, abs=0.005)
+
+ with pytest.raises(KeyError):
+ timer.get_duration("baz")
+
+ durations = timer.get_durations()
+ assert durations["foo"] == pytest.approx(0.01, abs=0.005)
+ assert durations["bar"] == pytest.approx(0.02, abs=0.005)
+
+
+def test_GlobalTimer():
+ with GlobalTimer():
+ exercise_active_timer()
+
+ # Check that timer is reset within new context.
+ with GlobalTimer():
+ timer = GlobalTimer.get_instance()
+ assert timer.get_durations() == {}
+
+
+def test_GlobalTimer_resets_after_exception():
+ with pytest.raises(ValueError):
+ with GlobalTimer():
+ timer = GlobalTimer.get_instance()
+ timer.start("foo")
+ raise ValueError()
+
+ # Check that the context manager clears the state of the timer after an
+    # exception. If it were not cleared, starting the timer for "foo" would raise
+ # an error.
+ with GlobalTimer():
+ timer = GlobalTimer.get_instance()
+ timer.start("foo")
+ timer.stop()
+
+
+def test_GlobalTimer_multiple_context_error():
+ with pytest.raises(RuntimeError, match="GlobalTimer is currently in use"):
+ with GlobalTimer(), GlobalTimer():
+ pass
+
+
+def test_inactive_GlobalTimer_warning():
+ with pytest.warns(UserWarning, match=r"inactive"):
+ GlobalTimer.get_instance()
+
+
+@pytest.mark.filterwarnings("ignore:The GlobalTimer")
+def test_inactive_GlobalTimer_start():
+ timer = GlobalTimer.get_instance()
+ timer.start("foo")
+
+
+@pytest.mark.filterwarnings("ignore:The GlobalTimer")
+def test_inactive_GlobalTimer_stop():
+ timer = GlobalTimer.get_instance()
+ timer.stop()
+
+
+@pytest.mark.filterwarnings("ignore:The GlobalTimer")
+def test_inactive_GlobalTimer_get_duration():
+ timer = GlobalTimer.get_instance()
+ result = timer.get_duration("foo")
+ assert np.isnan(result)
+
+
+@pytest.mark.filterwarnings("ignore:The GlobalTimer")
+def test_inactive_GlobalTimer_get_durations():
+ timer = GlobalTimer.get_instance()
+ result = timer.get_durations()
+ assert result == {}
+
+
+@pytest.mark.filterwarnings("ignore:The GlobalTimer")
+def test_inactive_GlobalTimer_log_durations(caplog):
+ timer = GlobalTimer.get_instance()
+ with caplog.at_level(logging.INFO):
+ timer.log_durations()
+ assert len(caplog.records) == 0
+
+
+@pytest.mark.filterwarnings("ignore:The GlobalTimer")
+def test_GlobalTimer_inactive_then_active():
+ # Make sure we can instantiate an inactive timer, but then still create and
+ # use an active timer later.
+ GlobalTimer.get_instance()
+
+ with GlobalTimer():
+ exercise_active_timer()
+
+
+def test_GlobalTimer_context():
+ with GlobalTimer():
+ timer = GlobalTimer.get_instance()
+ with timer.context("foo"):
+ time.sleep(0.01)
+ assert timer.get_duration("foo") > 0.01
+
+
+def test_GlobalTimer_context_with_exception():
+ with pytest.raises(ValueError):
+ with GlobalTimer():
+ timer = GlobalTimer.get_instance()
+ with timer.context("foo"):
+ time.sleep(0.01)
+ raise ValueError()
+ assert timer.get_duration("foo") > 0.01
+
+
+def test_GlobalTimer_single_inner_timer():
+ with pytest.raises(
+ RuntimeError, match="GlobalTimer already has an active inner timer"
+ ):
+ with GlobalTimer():
+ timer = GlobalTimer.get_instance()
+ with timer.context("foo"):
+ with timer.context("bar"):
+ pass
+
+
+def test_GlobalTimer_nested_outer_context():
+ with GlobalTimer():
+ timer = GlobalTimer.get_instance()
+ with timer.outer_context("foo"):
+ with timer.context("bar"):
+ pass
+
+
+def test_GlobalTimer_double_nested_outer_context():
+ with GlobalTimer():
+ timer = GlobalTimer.get_instance()
+ with timer.outer_context("foo"):
+ with timer.outer_context("bar"):
+ pass
diff --git a/fme/fme/core/testing/__init__.py b/fme/fme/core/testing/__init__.py
index 8de3120..6852c92 100644
--- a/fme/fme/core/testing/__init__.py
+++ b/fme/fme/core/testing/__init__.py
@@ -1,10 +1,2 @@
from .distributed import mock_distributed
-from .fv3gfs_data import (
- DimSizes,
- FV3GFSData,
- MonthlyReferenceData,
- StatsData,
- save_2d_netcdf,
- save_scalar_netcdf,
-)
from .wandb import mock_wandb
diff --git a/fme/fme/core/testing/wandb.py b/fme/fme/core/testing/wandb.py
index 1523a9c..e806854 100644
--- a/fme/fme/core/testing/wandb.py
+++ b/fme/fme/core/testing/wandb.py
@@ -1,6 +1,6 @@
import collections
import contextlib
-from typing import Any, Dict, Mapping
+from typing import Any, Dict, List, Mapping
from fme.core import wandb
from fme.core.distributed import Distributed
@@ -11,6 +11,7 @@ def __init__(self):
self._enabled = False
self._configured = False
self._logs: Dict[int, Dict[str, Any]] = collections.defaultdict(dict)
+ self._last_step = 0
def configure(self, log_to_wandb: bool):
dist = Distributed.get_instance()
@@ -31,12 +32,24 @@ def watch(self, modules):
pass
def log(self, data: Mapping[str, Any], step: int, sleep=None):
+ if step < self._last_step:
+ raise ValueError(
+ f"step {step} is less than last step {self._last_step}, "
+ "steps must be logged in order"
+ )
+ self._last_step = step
# sleep arg is ignored since we don't want to sleep in tests
if self._enabled:
self._logs[step].update(data)
- def get_logs(self) -> Dict[int, Dict[str, Any]]:
- return self._logs
+ def get_logs(self) -> List[Dict[str, Any]]:
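+        # Return logs as a list indexed by step; steps that were never logged are
+        # filled with empty dicts so callers can index by step number directly.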
+ if len(self._logs) == 0:
+ return []
+ n_logs = max(self._logs.keys())
+ return_value: List[Dict[str, Any]] = [dict() for _ in range(n_logs + 1)]
+ for step, log in self._logs.items():
+ return_value[step] = log
+ return return_value
def clean_wandb_dir(self, experiment_dir: str):
pass
diff --git a/fme/fme/core/timing.py b/fme/fme/core/timing.py
new file mode 100644
index 0000000..1f0f670
--- /dev/null
+++ b/fme/fme/core/timing.py
@@ -0,0 +1,168 @@
+import contextlib
+import logging
+import time
+import warnings
+from typing import Dict, Optional
+
+import numpy as np
+
+singleton: Optional["GlobalTimer"] = None
+
+
+INACTIVE_WARNING_MESSAGE = (
+ "The GlobalTimer is currently inactive; therefore no timing information "
+ "will be recorded. To activate it, wrap your code within the GlobalTimer() "
+ "context."
+)
+
+
+class CumulativeTimer:
+ def __init__(self, category):
+ self._duration = 0.0
+ self._start_time = None
+ self._category = category
+
+ def start(self):
+ if self._start_time is not None:
+ raise RuntimeError(f"timer {self._category!r} is already running")
+ self._start_time = time.time()
+
+ def stop(self):
+ if self._start_time is None:
+ raise RuntimeError(
+ f"must call start for timer {self._category!r} before stop"
+ )
+ self._duration += time.time() - self._start_time
+ self._start_time = None
+
+ @property
+ def duration(self) -> float:
+ if self._start_time is not None:
+ raise RuntimeError(f"timer {self._category!r} is still running")
+ return self._duration
+
+
+class GlobalTimer:
+ """
+ A singleton class to make timing inference code easier.
+ """
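+    # Example usage (a minimal sketch; the category name is an arbitrary label
+    # chosen by the caller):
+    #
+    #     with GlobalTimer():
+    #         timer = GlobalTimer.get_instance()
+    #         with timer.context("forward_step"):
+    #             ...  # timed code
+    #         timer.log_durations()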
+
+ @classmethod
+ def get_instance(cls) -> "GlobalTimer":
+ """
+ Get the singleton instance of the GlobalTimer class.
+ """
+ global singleton
+ if singleton is None:
+ singleton = cls()
+ warnings.warn(INACTIVE_WARNING_MESSAGE)
+ return singleton
+
+ @classmethod
+ def __enter__(cls):
+ global singleton
+ if singleton is None:
+ singleton = cls()
+ if singleton._active:
+ raise RuntimeError("GlobalTimer is currently in use in another context")
+ singleton._active = True
+
+ @classmethod
+ def __exit__(cls, type, value, traceback):
+ global singleton
+ singleton = None
+
+ def __init__(self):
+ self._timers: Dict[str, CumulativeTimer] = {}
+ self._active = False
+ self._current_category: Optional[str] = None
+
+ def outer_context(self, category: str) -> contextlib.AbstractContextManager:
+ """
+ Context manager for timing a block of code.
+
+ May be active at the same time as other timers.
+ """
+
+ @contextlib.contextmanager
+ def timer_context():
+ self.start_outer(category)
+ try:
+ yield
+ finally:
+ self.stop_outer(category)
+
+ return timer_context()
+
+ def context(self, category: str) -> contextlib.AbstractContextManager:
+ """
+ Context manager for timing a block of code.
+
+ Only one inner timer can be active at a time.
+ """
+
+ @contextlib.contextmanager
+ def timer_context():
+ self.start(category)
+ try:
+ yield
+ finally:
+ self.stop()
+
+ return timer_context()
+
+ def start(self, category: str):
+ """
+ Start an inner timer for the given category.
+
+ Only one inner timer can be active at a time.
+ """
+ if self._current_category is not None:
+ raise RuntimeError(
+ "GlobalTimer already has an active inner timer, "
+ f"{self._current_category}"
+ )
+ self.start_outer(category)
+ self._current_category = category
+
+ def start_outer(self, category: str):
+ """
+ Start a timer for the given category.
+
+ May be active at the same time as other timers.
+ """
+ if self._active:
+ if category not in self._timers:
+ self._timers[category] = CumulativeTimer(category)
+ self._timers[category].start()
+
+ def stop(self):
+ """
+ Stop the currently active inner timer.
+ """
+ if self._current_category is None:
+ raise RuntimeError("GlobalTimer does not have a running timer")
+ self.stop_outer(self._current_category)
+ self._current_category = None
+
+ def stop_outer(self, category: str):
+ """
+ Stop the timer for the given category.
+
+ Does not change the currently active inner timer.
+ """
+ if self._active:
+ self._timers[category].stop()
+
+ def get_duration(self, category: str) -> float:
+ if self._active:
+ return self._timers[category].duration
+ else:
+ return np.nan
+
+ def get_durations(self) -> Dict[str, float]:
+ return {category: timer.duration for category, timer in self._timers.items()}
+
+ def log_durations(self):
+ for name, duration in self.get_durations().items():
+ logging.info(f"{name} duration: {duration:.2f}s")
diff --git a/fme/fme/core/typing_.py b/fme/fme/core/typing_.py
index dcfac10..140691c 100644
--- a/fme/fme/core/typing_.py
+++ b/fme/fme/core/typing_.py
@@ -1,6 +1,29 @@
-from typing import Dict, Mapping
+import dataclasses
+from typing import Dict, Mapping, Optional
import torch
TensorMapping = Mapping[str, torch.Tensor]
TensorDict = Dict[str, torch.Tensor]
+
+
+@dataclasses.dataclass
+class Slice:
+ """
+ Configuration of a python `slice` built-in.
+
+ Required because `slice` cannot be initialized directly by dacite.
+
+ Parameters:
+ start: Start index of the slice.
+ stop: Stop index of the slice.
+ step: Step of the slice.
+ """
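+    # e.g. Slice(stop=4).slice == slice(None, 4, None), which selects the first
+    # four elements when used to index a sequence or tensor dimension.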
+
+ start: Optional[int] = None
+ stop: Optional[int] = None
+ step: Optional[int] = None
+
+ @property
+ def slice(self) -> slice:
+ return slice(self.start, self.stop, self.step)
diff --git a/fme/fme/core/wandb.py b/fme/fme/core/wandb.py
index d7d0d6f..703b8ae 100644
--- a/fme/fme/core/wandb.py
+++ b/fme/fme/core/wandb.py
@@ -117,12 +117,13 @@ def configure(self, log_to_wandb: bool):
self._configured = True
def init(self, **kwargs):
- """kwargs are passed to wandb.init"""
+ """Kwargs are passed to wandb.init."""
if not self._configured:
raise RuntimeError(
"must call WandB.configure before WandB init can be called"
)
if self._enabled:
+ wandb.require("core")
wandb.init(**kwargs)
def watch(self, modules):
diff --git a/fme/fme/core/weight_ops.py b/fme/fme/core/weight_ops.py
index d1d1174..8f3ce9f 100644
--- a/fme/fme/core/weight_ops.py
+++ b/fme/fme/core/weight_ops.py
@@ -25,7 +25,7 @@ class CopyWeightsConfig:
All parameters must be covered by either the include or exclude list,
but not both.
- Attributes:
+ Parameters:
include: list of wildcard patterns to overwrite
exclude: list of wildcard patterns to exclude from overwriting
"""
diff --git a/fme/fme/core/winds.py b/fme/fme/core/winds.py
index 78a55d4..1513bcc 100644
--- a/fme/fme/core/winds.py
+++ b/fme/fme/core/winds.py
@@ -26,7 +26,6 @@ def u_v_to_x_y_z_wind(
wy: y wind component
wz: z wind component
"""
-
# for a graphical proof of the equations used here, see
# https://github.com/ai2cm/full-model/pull/355#issuecomment-1729773301
diff --git a/fme/fme/require_gpu.py b/fme/fme/require_gpu.py
new file mode 100644
index 0000000..bac46cb
--- /dev/null
+++ b/fme/fme/require_gpu.py
@@ -0,0 +1,9 @@
+"""
+Manually triggered for CI tests on GPU so that tests do not
+default to CPU if driver issues prevent use of CUDA.
+"""
+
+import fme
+device = str(fme.get_device())
+print(f"Device: {device}")
+assert device.startswith("cuda")
diff --git a/fme/fme/sht_fix.py b/fme/fme/sht_fix.py
index cd5fa6b..2e8da49 100644
--- a/fme/fme/sht_fix.py
+++ b/fme/fme/sht_fix.py
@@ -10,8 +10,6 @@
[*] https://github.com/NVIDIA/torch-harmonics/blob/17eefa53468d1a885d72087918eba905fa53e10a/torch_harmonics/sht.py
"""
-USE_FIX = True
-
# coding=utf-8
@@ -49,12 +47,8 @@
import torch.nn as nn
import torch.fft
-if USE_FIX:
- from torch_harmonics.quadrature import *
- from torch_harmonics.legendre import *
-else:
- from .quadrature import *
- from .legendre import *
+from torch_harmonics.quadrature import legendre_gauss_weights, lobatto_weights, clenshaw_curtiss_weights
+from torch_harmonics.legendre import precompute_legpoly, precompute_dlegpoly
class RealSHT(nn.Module):
@@ -115,10 +109,7 @@ def __init__(self, nlat, nlon, lmax=None, mmax=None, grid="lobatto", norm="ortho
weights = torch.einsum('mlk,k->mlk', pct, weights)
# remember quadrature weights
- if USE_FIX:
- self.weights = weights.float()
- else:
- self.register_buffer('weights', weights, persistent=False)
+ self.weights = weights.float()
def extra_repr(self):
"""
@@ -144,8 +135,7 @@ def forward(self, x: torch.Tensor):
xout = torch.zeros(out_shape, dtype=x.dtype, device=x.device)
# contraction
- if USE_FIX:
- self.weights = self.weights.to(x.device)
+ self.weights = self.weights.to(x.device)
xout[..., 0] = torch.einsum('...km,mlk->...lm', x[..., :self.mmax, 0], self.weights)
xout[..., 1] = torch.einsum('...km,mlk->...lm', x[..., :self.mmax, 1], self.weights)
x = torch.view_as_complex(xout)
@@ -197,10 +187,7 @@ def __init__(self, nlat, nlon, lmax=None, mmax=None, grid="lobatto", norm="ortho
pct = precompute_legpoly(self.mmax, self.lmax, t, norm=self.norm, inverse=True, csphase=self.csphase)
# register buffer
- if USE_FIX:
- self.pct = pct.float()
- else:
- self.register_buffer('pct', pct, persistent=False)
+ self.pct = pct.float()
def extra_repr(self):
"""
@@ -216,8 +203,7 @@ def forward(self, x: torch.Tensor):
# Evaluate associated Legendre functions on the output nodes
x = torch.view_as_real(x)
- if USE_FIX:
- self.pct = self.pct.to(x.device)
+ self.pct = self.pct.to(x.device)
rl = torch.einsum('...lm, mlk->...km', x[..., 0], self.pct )
im = torch.einsum('...lm, mlk->...km', x[..., 1], self.pct )
xs = torch.stack((rl, im), -1)
@@ -291,10 +277,7 @@ def __init__(self, nlat, nlon, lmax=None, mmax=None, grid="lobatto", norm="ortho
weights[1] = -1 * weights[1]
# remember quadrature weights
- if USE_FIX:
- self.weights = weights.float()
- else:
- self.register_buffer('weights', weights, persistent=False)
+ self.weights = weights.float()
def extra_repr(self):
"""
@@ -380,10 +363,7 @@ def __init__(self, nlat, nlon, lmax=None, mmax=None, grid="lobatto", norm="ortho
dpct = precompute_dlegpoly(self.mmax, self.lmax, t, norm=self.norm, inverse=True, csphase=self.csphase)
# register weights
- if USE_FIX:
- self.dpct = dpct.float()
- else:
- self.register_buffer('dpct', dpct, persistent=False)
+ self.dpct = dpct.float()
def extra_repr(self):
"""
@@ -401,8 +381,7 @@ def forward(self, x: torch.Tensor):
# contraction - spheroidal component
# real component
- if USE_FIX:
- self.dpct = self.dpct.to(x.device)
+ self.dpct = self.dpct.to(x.device)
srl = torch.einsum('...lm,mlk->...km', x[..., 0, :, :, 0], self.dpct[0]) \
- torch.einsum('...lm,mlk->...km', x[..., 1, :, :, 1], self.dpct[1])
# iamg component
diff --git a/fme/pyproject.toml b/fme/pyproject.toml
index 165f261..0ce45e9 100644
--- a/fme/pyproject.toml
+++ b/fme/pyproject.toml
@@ -33,3 +33,22 @@ optional-dependencies.deploy = { file = "deploy-requirements.txt" }
[tool.setuptools.packages]
find = {}
+
+[tool.uv]
+cache-keys = [
+ { file = "requirements.txt" },
+ { file = "dev-requirements.txt" },
+ { file = "docs/requirements.txt" },
+]
+
+[tool.ruff.lint]
+select = ["D", "E", "F", "I", "W"]
+ignore = ["D1", "D200", "D205", "D212", "E203", "W293", "F541", "E402"]
+
+[tool.ruff.lint.per-file-ignores]
+"*/__init__.py" = ["F401"]
+"scripts/*" = ["D"]
+"test_*.py" = ["D"]
+
+[tool.ruff.lint.pydocstyle]
+convention = "google"
diff --git a/fme/pytest.ini b/fme/pytest.ini
deleted file mode 100644
index 4c45e1d..0000000
--- a/fme/pytest.ini
+++ /dev/null
@@ -1,3 +0,0 @@
-[pytest]
-markers =
- requires_gpu: these tests require a GPU to run
\ No newline at end of file
diff --git a/fme/requirements.txt b/fme/requirements.txt
index 587ce0b..e1ed2a7 100644
--- a/fme/requirements.txt
+++ b/fme/requirements.txt
@@ -1,9 +1,9 @@
h5py
imageio<=2.27.0
-moviepy
+moviepy<2.0.0 # should be able to relax this after wandb updates past 0.18.7
netcdf4
numpy<2
-wandb
+wandb[media]
tensorly
tensorly-torch
xarray
@@ -17,4 +17,4 @@ plotly
matplotlib
dask
astropy-healpix
-pandas
+pandas
\ No newline at end of file
diff --git a/scripts/data_process/Makefile b/scripts/data_process/Makefile
new file mode 100644
index 0000000..069b50f
--- /dev/null
+++ b/scripts/data_process/Makefile
@@ -0,0 +1,246 @@
+# The dependencies of the scripts below (where not containerized) are installed in the "fv3net" conda environment
+# which can be installed using fv3net's Makefile. See
+# https://github.com/ai2cm/fv3net/blob/8ed295cf0b8ca49e24ae5d6dd00f57e8b30169ac/Makefile#L310
+
+# Some data has been moved to Coldline storage, making it expensive
+# to access. In order to run processing on that data, set this
+# variable to true.
+ENABLE_COLDLINE ?= false
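+# For example: make fv3gfs_climSST_stats_beaker_dataset ENABLE_COLDLINE=true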
+
+RESOLUTION ?= 1deg
+LAYERS ?= 8layer
+
+ROOT_BASE_CLIMSST_1DEG = gs://vcm-ml-raw-flexible-retention/2023-07-08-C96-FME-reference-ensemble/regridded-zarrs/gaussian_grid_180_by_360
+ROOT_BASE_CLIMSST_4DEG = gs://vcm-ml-raw-flexible-retention/2023-07-08-C96-FME-reference-ensemble/regridded-zarrs/gaussian_grid_45_by_90
+ROOT_BASE_AMIP_1DEG = gs://vcm-ml-raw-flexible-retention/2023-10-20-C96-FME-AMIP-ensemble-reference/regridded-zarrs/gaussian_grid_180_by_360
+ROOT_BASE_AMIP_4DEG = gs://vcm-ml-raw-flexible-retention/2023-10-20-C96-FME-AMIP-ensemble-reference/regridded-zarrs/gaussian_grid_45_by_90
+ROOT_INTERMEDIATE_CLIMSST_1DEG = gs://vcm-ml-intermediate/2023-08-09-vertically-resolved-1deg-fme-ensemble-dataset
+ROOT_INTERMEDIATE_CLIMSST_4DEG = gs://vcm-ml-intermediate/2023-08-09-vertically-resolved-4deg-fme-ensemble-dataset
+ROOT_INTERMEDIATE_AMIP_1DEG = gs://vcm-ml-intermediate/2023-10-27-vertically-resolved-1deg-fme-amip-ensemble-dataset
+ROOT_INTERMEDIATE_AMIP_4DEG = gs://vcm-ml-intermediate/2023-10-27-vertically-resolved-4deg-fme-amip-ensemble-dataset
+OUTPUT_DIR_CLIMSST_1DEG = /net/nfs/climate/data/2023-08-11-vertically-resolved-1deg-fme-ensemble-dataset
+OUTPUT_DIR_CLIMSST_4DEG = /net/nfs/climate/data/2023-08-11-vertically-resolved-4deg-fme-ensemble-dataset
+OUTPUT_DIR_AMIP_1DEG = /pscratch/sd/b/bhenn/data/2023-11-01-vertically-resolved-1deg-fme-amip-ensemble-dataset
+OUTPUT_DIR_AMIP_4DEG = /pscratch/sd/b/bhenn/data/2023-11-01-vertically-resolved-4deg-fme-amip-ensemble-dataset
+NAME_CLIMSST_1DEG = fv3gfs-ensemble
+NAME_CLIMSST_4DEG = fv3gfs-ensemble-4deg
+NAME_AMIP_1DEG = fv3gfs-AMIP-ensemble
+NAME_AMIP_4DEG = fv3gfs-AMIP-ensemble-4deg
+
+ifeq ($(RESOLUTION),1deg)
+ ROOT_INTERMEDIATE ?= $(ROOT_INTERMEDIATE_CLIMSST_1DEG)
+ ROOT_INTERMEDIATE_AMIP ?= $(ROOT_INTERMEDIATE_AMIP_1DEG)
+ ROOT_BASE ?= $(ROOT_BASE_CLIMSST_1DEG)
+ ROOT_AMIP ?= $(ROOT_BASE_AMIP_1DEG)
+ OUTPUT_DIR ?= $(OUTPUT_DIR_CLIMSST_1DEG)
+ OUTPUT_DIR_AMIP ?= $(OUTPUT_DIR_AMIP_1DEG)
+ NAME ?= $(NAME_CLIMSST_1DEG)
+ NAME_AMIP ?= $(NAME_AMIP_1DEG)
+else ifeq ($(RESOLUTION),4deg)
+ ROOT_INTERMEDIATE ?= $(ROOT_INTERMEDIATE_CLIMSST_4DEG)
+ ROOT_INTERMEDIATE_AMIP ?= $(ROOT_INTERMEDIATE_AMIP_4DEG)
+ ROOT_BASE ?= $(ROOT_BASE_CLIMSST_4DEG)
+ ROOT_AMIP ?= $(ROOT_BASE_AMIP_4DEG)
+ OUTPUT_DIR ?= $(OUTPUT_DIR_CLIMSST_4DEG)
+ OUTPUT_DIR_AMIP ?= $(OUTPUT_DIR_AMIP_4DEG)
+ NAME ?= $(NAME_CLIMSST_4DEG)
+ NAME_AMIP ?= $(NAME_AMIP_4DEG)
+endif
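+# e.g. "make fv3gfs_climSST_monthly_netcdfs RESOLUTION=4deg" selects the 4-degree
+# intermediate and output paths defined above.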
+
+# the netCDF generation step is done locally usually, so user sets input and output directories
+NC_INPUT ?=
+NC_OUTPUT ?=
+
+SHIELD_RES ?= c96
+
+.PHONY: shield_AMIP_dataset
+shield_AMIP_dataset:
+ ./compute_dataset.sh \
+ --config configs/shield-amip-ensemble-$(SHIELD_RES)-$(RESOLUTION)-8layer.yaml
+
+.PHONY: shield_AMIP_monthly_netcdfs
+shield_AMIP_monthly_netcdfs:
+ ./convert_to_monthly_netcdf_fv3gfs.sh \
+ --input-url $(NC_INPUT) \
+ --n-ic 2 \
+ --output-dir $(NC_OUTPUT) \
+ --start-date 1940-01-01 \
+ --end-date 2021-12-31
+
+.PHONY: shield_c24_4deg_climSST_dataset
+shield_c24_4deg_climSST_dataset:
+ ./compute_dataset.sh --config configs/shield-c24-ensemble-4deg-8layer.yaml
+
+.PHONY: shield_c24_4deg_climSST_monthly_netcdfs
+shield_c24_4deg_climSST_monthly_netcdfs:
+ ./convert_to_monthly_netcdf_fv3gfs.sh \
+ --input-url gs://vcm-ml-intermediate/2024-04-05-vertically-resolved-4deg-c24-shield-fme-ensemble-dataset \
+ --n-ic 21 \
+ --output-dir $(NC_OUTPUT) \
+ --start-date 2021-01-01 \
+ --end-date 2030-12-31
+
+shield_c24_4deg_climSST_stats_beaker_dataset:
+ ./compute_stats.sh --config configs/shield-c24-ensemble-4deg-8layer.yaml
+
+.PHONY: shield_c96_4deg_climSST_dataset
+shield_c96_4deg_climSST_dataset:
+ ./compute_dataset.sh --config configs/shield-c96-4deg-8layer.yaml
+
+.PHONY: shield_c96_4deg_climSST_monthly_netcdfs
+shield_c96_4deg_climSST_monthly_netcdfs:
+ ./convert_to_monthly_netcdf_fv3gfs.sh \
+ --input-url gs://vcm-ml-intermediate/2024-04-02-vertically-resolved-4deg-c96-shield-fme-dataset \
+ --n-ic 1 \
+ --output-dir $(NC_OUTPUT) \
+ --start-date 2035-01-01 \
+ --end-date 2060-12-31
+
+shield_c96_4deg_climSST_stats_beaker_dataset:
+ ./compute_stats.sh --config configs/shield-c96-4deg-8layer.yaml
+
+.PHONY: fv3gfs_1deg_climSST_dataset
+fv3gfs_1deg_climSST_dataset:
+ $(MAKE) fv3gfs_climSST_dataset RESOLUTION=1deg LAYERS=8layer
+
+.PHONY: fv3gfs_climSST_dataset
+fv3gfs_climSST_dataset:
+ ./compute_dataset.sh --config configs/fv3gfs-ensemble-$(RESOLUTION)-$(LAYERS).yaml
+
+.PHONY: fv3gfs_1deg_climSST_monthly_netcdfs
+fv3gfs_1deg_climSST_monthly_netcdfs:
+ $(MAKE) fv3gfs_climSST_monthly_netcdfs RESOLUTION=1deg
+
+.PHONY: enable_coldline_check
+enable_coldline_check:
+ @if [ "$(ENABLE_COLDLINE)" != "true" ]; then \
+		echo "Processing target is deprecated due to coldlined data;"; \
+		echo "to run, ENABLE_COLDLINE must be set to true. Exiting."; \
+ exit 1; \
+ fi
+
+.PHONY: fv3gfs_climSST_monthly_netcdfs
+fv3gfs_climSST_monthly_netcdfs:
+ ./convert_to_monthly_netcdf_fv3gfs.sh \
+ --input-url $(ROOT_INTERMEDIATE) \
+ --n-ic 11 \
+ --output-dir $(OUTPUT_DIR) \
+ --start-date 2021-01-01 \
+ --end-date 2030-12-31
+
+.PHONY: fv3gfs_1deg_climSST_stats_beaker_dataset
+fv3gfs_1deg_climSST_stats_beaker_dataset:
+ $(MAKE) fv3gfs_climSST_stats_beaker_dataset RESOLUTION=1deg
+
+fv3gfs_climSST_stats_beaker_dataset: enable_coldline_check
+ ./compute_stats.sh --config configs/fv3gfs-ensemble-$(RESOLUTION)-$(LAYERS).yaml
+
+# This took around 10 hours to complete. The roundtrip_filter adds significant computational time.
+# If we plan to do this regularly, parallelizing the script across time using xpartition to
+# launch jobs for different time chunks would probably be a good idea.
+fv3gfs_climSST_c48_baseline_dataset: enable_coldline_check
+ ./compute_dataset.sh --config configs/fv3gfs-c48-ensemble-1deg-8layer.yaml
+
+.PHONY: fv3gfs_climSST_c48_baseline_monthly_netcdfs
+fv3gfs_climSST_c48_baseline_monthly_netcdfs:
+ python convert_to_monthly_netcdf.py \
+ --prepend-nans \
+ gs://vcm-ml-intermediate/2023-09-01-vertically-resolved-1deg-fme-c48-baseline-dataset/ic_0011.zarr \
+ /net/nfs/climate/data/2023-09-12-vertically-resolved-1deg-fme-c48-baseline-dataset-truncated-065/ic_0011 \
+ --start-date 2021-01-01 \
+ --end-date 2030-12-31
+
+.PHONY: fv3gfs_1deg_AMIP_dataset
+fv3gfs_1deg_AMIP_dataset:
+ $(MAKE) fv3gfs_AMIP_dataset RESOLUTION=1deg LAYERS=8layer
+
+fv3gfs_AMIP_dataset: enable_coldline_check
+ ./compute_dataset.sh --config configs/fv3gfs-amip-ensemble-$(RESOLUTION)-$(LAYERS).yaml
+
+.PHONY: fv3gfs_1deg_AMIP_monthly_netcdfs
+fv3gfs_1deg_AMIP_monthly_netcdfs:
+ $(MAKE) fv3gfs_AMIP_monthly_netcdfs RESOLUTION=1deg
+
+.PHONY: fv3gfs_AMIP_monthly_netcdfs
+fv3gfs_AMIP_monthly_netcdfs:
+ ./convert_to_monthly_netcdf_fv3gfs.sh \
+ --input-url $(ROOT_INTERMEDIATE_AMIP) \
+ --n-ic 4 \
+ --output-dir $(OUTPUT_DIR_AMIP) \
+ --start-date 1990-01-01 \
+ --end-date 2019-12-31
+
+.PHONY: fv3gfs_1deg_AMIP_stats_beaker_dataset
+fv3gfs_1deg_AMIP_stats_beaker_dataset:
+ $(MAKE) fv3gfs_AMIP_stats_beaker_dataset RESOLUTION=1deg
+
+fv3gfs_AMIP_stats_beaker_dataset: enable_coldline_check
+ ./compute_stats.sh --config configs/fv3gfs-amip-ensemble-$(RESOLUTION)-$(LAYERS).yaml
+
+# TODO: Add AMIP baseline C48 dataset processing when available
+
+.PHONY: shield_som_c96_spin_up_dataset
+shield_som_c96_spin_up_dataset:
+ ./compute_dataset.sh --config configs/shield-som-spin-up-c96-1deg-$(LAYERS).yaml
+
+.PHONY: shield_som_ensemble_c96_dataset
+shield_som_ensemble_c96_dataset:
+ ./compute_dataset.sh --config configs/shield-som-ensemble-c96-$(RESOLUTION)-$(LAYERS).yaml
+
+.PHONY: shield_som_abrupt_co2_increase_c96_dataset
+shield_som_abrupt_co2_increase_c96_dataset:
+ ./compute_dataset.sh --config configs/shield-som-abrupt-co2-increase-c96-$(RESOLUTION)-$(LAYERS).yaml
+
+.PHONY: shield_som_c96_increasing_co2_dataset
+shield_som_c96_increasing_co2_dataset:
+ ./compute_dataset.sh --config configs/shield-som-increasing-co2-c96-$(RESOLUTION)-$(LAYERS).yaml
+
+.PHONY: shield_som_c96_radiation_multi_call_dataset
+shield_som_c96_radiation_multi_call_dataset:
+ ./compute_dataset.sh --config configs/shield-som-radiation-multi-call-c96-1deg-$(LAYERS).yaml
+
+.PHONY: shield_som_c24_dataset
+shield_som_c24_dataset:
+ ./compute_dataset.sh --config configs/shield-som-c24-4deg-$(LAYERS).yaml
+
+.PHONY: shield_som_c24_tuned_cdmbgwd_dataset
+shield_som_c24_tuned_cdmbgwd_dataset:
+ ./compute_dataset.sh --config configs/shield-som-c24-tuned-cdmbgwd-4deg-$(LAYERS).yaml
+
+# In total (not including full_field stats), this took 7 hours 15 minutes to
+# complete using a single Perlmutter CPU node.
+
+.PHONY: e3smv2_1deg_climSST_dataset
+e3smv2_1deg_climSST_dataset:
+ sbatch -J 2024-07-10-e3smv2-1deg-testing generate_datasets_e3smv2.sh \
+ --input-dir /global/cfs/cdirs/m4492/rebassoo/e3sm_test/post/atm/180x360_gaussian/ts \
+ --config configs/e3sm-1deg-8layer.yaml \
+ --zarr /global/cfs/cdirs/m4492/fme-preprocess/zarr/2024-07-10-e3smv2-1deg-testing.zarr \
+ --output-dir /global/cfs/cdirs/m4492/fme-preprocess/2024-07-10-e3smv2-1deg-testing
+
+.PHONY: era5_1deg_stats_beaker_dataset
+era5_1deg_stats_beaker_dataset:
+ ./compute_stats.sh --config configs/era5-1deg-8layer-1940-2022.yaml
+
+.PHONY: era5_1deg_16layer_stats_beaker_dataset
+era5_1deg_16layer_stats_beaker_dataset:
+ ./compute_stats.sh --config configs/era5-1deg-16layer-1940-2022.yaml
+
+.PHONY: compute_cm4_trial_run_atmosphere_dataset
+compute_cm4_trial_run_atmosphere_dataset:
+ python -u compute_dataset.py \
+ --config configs/pre-industrial-CM4-1deg-8layer-trial-run.yaml \
+ --run-directory gs://vcm-ml-raw-flexible-retention/2024-08-10-pre-industrial-CM4-simulation/regridded-zarrs/gaussian_grid_180_by_360/trial-run \
+ --output-store gs://vcm-ml-intermediate/2024-09-20-cm4-1deg-8layer-trial-run.zarr
+
+.PHONY: compute_cm4_trial_run_atmosphere_stats
+compute_cm4_trial_run_atmosphere_stats:
+ python -u get_stats.py configs/pre-industrial-CM4-1deg-8layer-trial-run.yaml 0
+
+.PHONY: build_docker_image
+build_docker_image:
+ docker build -f beakerpy.Dockerfile -t us.gcr.io/vcm-ml/beaker-py .
+
+.PHONY: push_docker_image
+push_docker_image: build_docker_image
+ docker push us.gcr.io/vcm-ml/beaker-py
diff --git a/scripts/data_process/README.md b/scripts/data_process/README.md
index 40dc8ff..4c1706f 100644
--- a/scripts/data_process/README.md
+++ b/scripts/data_process/README.md
@@ -1,3 +1,28 @@
# Data processing for full model emulation training
-This directory contains scripts for generating various datasets needed for FME training
\ No newline at end of file
+This directory contains scripts for generating various datasets needed for FME training, including the FV3GFS primary, baseline, and stats datasets.
+
+It also contains scripts for generating E3SM training data.
+
+The first step in the process to create intermediate datasets (e.g. `make fv3gfs_AMIP_dataset`) uses argo, and can be run on your Google VM.
+See the vcm-workflow-control repo for instructions on how to install and run argo.
+
+The second step, which produces monthly netCDF files locally (e.g. `make fv3gfs_AMIP_monthly_netcdfs`), can be run on cirrascale in an interactive session.
+To create an interactive session, run the following command from the `scripts/data_process` directory:
+
+```
+beaker session create --budget ai2/climate --image beaker://jeremym/fme-2bc0033e --gpus 0 --mount hostPath:///net/nfs/climate=/net/nfs/climate --mount hostpath://$(pwd)=/full-model --workdir /full-model/scripts/data_process
+```
+
+Doing so will require that your current working directory is a mountable path (e.g. something in /data).
+If you'd like to write to a different directory than /net/nfs/climate, you can mount that path instead.
+
+Once inside the image, you will need to authorize access to GCS by running `gcloud auth application-default login` and following the instructions, then running `gcloud config set project vcm-ml` afterwards.
+
+You can then produce the monthly netCDFs in a target directory by modifying the `OUTPUT_DIR` or `OUTPUT_DIR_AMIP` variable in the make command below.
+
+```
+make fv3gfs_AMIP_monthly_netcdfs RESOLUTION=4deg OUTPUT_DIR_AMIP=/data/shared/2023-12-20-vertically-resolved-4deg-fme-amip-ensemble-dataset
+```
+
+The stats dataset creation step (e.g. `make fv3gfs_AMIP_stats_beaker_dataset`) must be run in the fme conda environment (created by `make create_environment` at the top level of this repo), and additionally requires that the beaker client is installed ([install instructions](https://beaker-docs.apps.allenai.org/start/install.html)).
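+
+For example, to build the 1-degree AMIP stats dataset (assuming the beaker client and GCS credentials are already configured):
+
+```
+make fv3gfs_AMIP_stats_beaker_dataset RESOLUTION=1deg ENABLE_COLDLINE=true
+```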
diff --git a/scripts/data_process/beakerpy.Dockerfile b/scripts/data_process/beakerpy.Dockerfile
new file mode 100644
index 0000000..92e7aa7
--- /dev/null
+++ b/scripts/data_process/beakerpy.Dockerfile
@@ -0,0 +1,3 @@
+FROM python:3.10-slim
+
+RUN pip install beaker-py==1.30.0 click dacite fsspec==2024.6.1 gcsfs==2024.6.1
diff --git a/scripts/data_process/combine_stats.py b/scripts/data_process/combine_stats.py
index c398cb7..d86d565 100644
--- a/scripts/data_process/combine_stats.py
+++ b/scripts/data_process/combine_stats.py
@@ -36,7 +36,7 @@ class Config:
def add_history_attrs(ds, config_filename: str, stats_output_dir: str):
ds.attrs["history"] = (
- "Created by ace/fv3gfs_data_process/combine_stats.py from "
+ "Created by full-model/fv3gfs_data_process/combine_stats.py from "
f"configuration file {config_filename} using inputs at {stats_output_dir}."
)
diff --git a/scripts/data_process/compute_dataset.py b/scripts/data_process/compute_dataset.py
index 4a948a4..7773eb3 100755
--- a/scripts/data_process/compute_dataset.py
+++ b/scripts/data_process/compute_dataset.py
@@ -10,20 +10,25 @@
# for a workflow which parallelizes this script across the 11-member ensemble and runs
# it on our GKE cluster.
+import abc
import dataclasses
import os
-from typing import List, Mapping, MutableMapping, Optional, Sequence, Tuple
+import sys
+from typing import Dict, List, Mapping, MutableMapping, Optional, Sequence, Tuple, Union
import click
import dacite
import fsspec
import numpy as np
import xarray as xr
-import xpartition # noqa: 401
+import xpartition # noqa: F401
import xtorch_harmonics
import yaml
from dask.diagnostics import ProgressBar
+sys.path.append(os.path.dirname(os.path.abspath(__file__)))
+from get_stats import StatsConfig
+
# constants are defined as in FV3GFS model
# https://github.com/ai2cm/fv3gfs-fortran/blob/master/FMS/constants/constants.F90
LATENT_HEAT_OF_VAPORIZATION = 2.5e6 # J/kg
@@ -41,6 +46,7 @@ class StandardNameMapping:
longitude_dim: str = "grid_xt"
latitude_dim: str = "grid_yt"
vertical_dim: str = "pfull"
+ vertical_interface_dim: str = "phalf"
time_dim: str = "time"
surface_pressure: str = "PRESsfc"
latent_heat_flux: str = "LHTFLsfc"
@@ -56,6 +62,11 @@ class StandardNameMapping:
snow_mixing_ratio: str = "snow_mixing_ratio"
northward_wind: str = "northward_wind"
eastward_wind: str = "eastward_wind"
+ surface_evaporation_rate: str = "surface_evaporation_rate"
+ land_fraction: str = "land_fraction"
+ ocean_fraction: str = "ocean_fraction"
+ sea_ice_fraction: str = "sea_ice_fraction"
+ hybrid_level_coeffs: List[str] = dataclasses.field(default_factory=list)
def __post_init__(self):
self.horizontal_dims: List[str] = [self.longitude_dim, self.latitude_dim]
@@ -65,15 +76,6 @@ def __post_init__(self):
self.pwat_tendency = f"tendency_of_{self.total_water_path}"
self.time_derivative_names = [self.total_water_path]
- self.water_species: List[str] = [
- self.specific_humidity,
- self.cloud_water_mixing_ratio,
- self.cloud_ice_mixing_ratio,
- self.graupel_mixing_ratio,
- self.rain_mixing_ratio,
- self.snow_mixing_ratio,
- ]
-
self.vertically_resolved: List[str] = [
self.specific_total_water,
self.air_temperature,
@@ -85,17 +87,60 @@ def __post_init__(self):
self.dropped_variables: List[str] = (
self.water_species
+ self.vertically_resolved
- + [self.pressure_thickness, self.vertical_dim, self.precipitable_water_path]
+ + [self.pressure_thickness, self.vertical_dim]
)
+ if self.precipitable_water_path.lower() != "none":
+ self.dropped_variables.append(self.precipitable_water_path)
+
+ @property
+ def water_species(self) -> List[str]:
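+        # Species whose configured name is "none" are treated as absent from the
+        # dataset and are excluded from the water species list.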
+ return [
+ item
+ for item in [
+ self.specific_humidity,
+ self.cloud_water_mixing_ratio,
+ self.cloud_ice_mixing_ratio,
+ self.graupel_mixing_ratio,
+ self.rain_mixing_ratio,
+ self.snow_mixing_ratio,
+ ]
+ if item.lower() != "none"
+ ]
+
+
+@dataclasses.dataclass
+class DLWPNameMapping(StandardNameMapping):
+ longitude_dim: str = "longitude"
+ latitude_dim: str = "latitude"
+ face_dim: str = "face"
+ width_dim: str = "width"
+ height_dim: str = "height"
+
+ def __post_init__(self):
+ super().__post_init__()
+
+ self.horizontal_dims: List[str] = [
+ self.face_dim,
+ self.width_dim,
+ self.height_dim,
+ ]
+ self.lat_lon_dims: List[str] = [self.latitude_dim, self.longitude_dim]
@dataclasses.dataclass
-class ChunkingConfig:
+class _ChunkingConfig(abc.ABC):
time_dim: int = 160
+
+ @abc.abstractmethod
+ def get_chunks(self, standard_names: StandardNameMapping) -> Dict[str, int]: ...
+
+
+@dataclasses.dataclass
+class ChunkingConfig(_ChunkingConfig):
latitude_dim: int = -1
longitude_dim: int = -1
- def get_chunks(self, standard_names: StandardNameMapping):
+ def get_chunks(self, standard_names: StandardNameMapping) -> Dict[str, int]:
return {
standard_names.time_dim: self.time_dim,
standard_names.longitude_dim: self.longitude_dim,
@@ -103,11 +148,32 @@ def get_chunks(self, standard_names: StandardNameMapping):
}
+@dataclasses.dataclass
+class DLWPChunkingConfig(_ChunkingConfig):
+ face_dim: int = -1
+ width_dim: int = -1
+ height_dim: int = -1
+
+ def get_chunks(self, standard_names: StandardNameMapping) -> Dict[str, int]:
+ dlwp_names = standard_names
+ if not isinstance(dlwp_names, DLWPNameMapping):
+ raise TypeError(
+ "Expected DLWPChunkingConfig to be passed type of DLWPNameMapping."
+ )
+ chunks = {
+ dlwp_names.time_dim: self.time_dim,
+ dlwp_names.face_dim: self.face_dim,
+ dlwp_names.width_dim: self.width_dim,
+ dlwp_names.height_dim: self.height_dim,
+ }
+ return chunks
+
+
@dataclasses.dataclass
class DatasetComputationConfig:
"""Configuration of computation details for an FME reference dataset.
- Attributes:
+ Parameters:
reference_vertical_coordinate_file: path to netCDF file containing
vertical coordinate definition for the reference simulation.
vertical_coarsening_indices: list of tuples defining the ranges of
@@ -123,6 +189,8 @@ class DatasetComputationConfig:
names of variables in the dataset.
chunking: (optional) mapping of standard dimension names to desired
output chunk sizes
+ time_invariant_dir: (optional) path to a directory containing time-invariant
+ data. This option is used for the E3SMv2 dataset.
"""
reference_vertical_coordinate_file: str
@@ -131,24 +199,31 @@ class DatasetComputationConfig:
n_split: int = 65
renaming: Mapping[str, str] = dataclasses.field(default_factory=dict)
roundtrip_fraction_kept: Optional[float] = None
- standard_names: StandardNameMapping = StandardNameMapping()
- chunking: ChunkingConfig = ChunkingConfig()
+ standard_names: Union[StandardNameMapping, DLWPNameMapping] = dataclasses.field(
+ default_factory=StandardNameMapping
+ )
+ chunking: Union[ChunkingConfig, DLWPChunkingConfig] = dataclasses.field(
+ default_factory=ChunkingConfig
+ )
+ time_invariant_dir: Optional[str] = None
@dataclasses.dataclass
class DatasetConfig:
"""Dataset provenance for a set of reference simulations.
- Attributes:
+ Parameters:
runs: mapping of short names to full paths of reference datasets.
output_directory: path to place output of computation script.
dataset_computation: configuration details for dataset
computation.
+ stats: configuration for computing and retrieving the statistics dataset.
"""
runs: Mapping[str, str]
data_output_directory: str
dataset_computation: DatasetComputationConfig
+ stats: StatsConfig
@classmethod
def from_file(cls, path: str) -> "DatasetConfig":
@@ -156,7 +231,7 @@ def from_file(cls, path: str) -> "DatasetConfig":
data = yaml.safe_load(file)
return dacite.from_dict(
- data_class=cls, data=data, config=dacite.Config(cast=[tuple])
+ data_class=cls, data=data, config=dacite.Config(cast=[tuple], strict=True)
)
@@ -222,16 +297,90 @@ def get_coarse_ak_bk(
return xr.Dataset(data)
+def compute_ocean_fraction(
+ ds: xr.Dataset,
+ output_name: str,
+ land_fraction_name: str,
+ sea_ice_fraction_name: str,
+) -> xr.Dataset:
+ """Compute latent heat flux, if needed."""
+ if output_name in ds.data_vars:
+ # if ocean_fraction is already computed, assume that NaNs have been handled
+ return ds
+ ds[sea_ice_fraction_name] = ds[sea_ice_fraction_name].fillna(0.0)
+ ocean_fraction = 1 - ds[sea_ice_fraction_name] - ds[land_fraction_name]
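+ # Where land + sea ice fractions exceed 1, ocean_fraction goes negative; clip it
+ # to zero and subtract the excess from the sea ice fraction so the three still sum to 1.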
+ negative_ocean = xr.where(ocean_fraction < 0, ocean_fraction, 0)
+ ocean_fraction -= negative_ocean
+ ds["sea_ice_fraction"] += negative_ocean
+ ocean_fraction.attrs["units"] = "unitless"
+ ocean_fraction.attrs["long_name"] = "fraction of grid cell area occupied by ocean"
+ return ds.assign({output_name: ocean_fraction})
+
+
+def compute_latent_heat_flux(
+ ds: xr.Dataset,
+ output_name: str,
+ evaporation_name: Optional[str] = None,
+) -> xr.Dataset:
+ """Compute latent heat flux, if needed."""
+ if output_name in ds.data_vars:
+ return ds
+ assert (
+ evaporation_name is not None
+ ), f"{output_name} not found in ds, evaporation_name must be provided."
+ latent_heat_flux = ds[evaporation_name] * LATENT_HEAT_OF_VAPORIZATION
+ latent_heat_flux.attrs["units"] = "W/m^2"
+ latent_heat_flux.attrs["long_name"] = "Latent heat flux"
+ return ds.assign({output_name: latent_heat_flux}).drop(evaporation_name)
+
+
def compute_specific_total_water(
ds: xr.Dataset, water_condensate_names: Sequence[str], output_name: str
) -> xr.Dataset:
"""Compute specific total water from individual water species."""
- specific_total_water = sum([ds[name] for name in water_condensate_names])
+ specific_total_water: xr.DataArray = sum(
+ [ds[name] for name in water_condensate_names]
+ )
specific_total_water.attrs["units"] = "kg/kg"
specific_total_water.attrs["long_name"] = output_name.replace("_", " ")
return ds.assign({output_name: specific_total_water})
+def compute_pressure_thickness(
+ ds: xr.Dataset,
+ vertical_coordinate_file: str,
+ vertical_dim_name: str,
+ surface_pressure_name: str,
+ output_name: str,
+ z_dim: str = "xaxis_1",
+):
+ if output_name in ds.data_vars:
+ return ds
+
+ with fsspec.open(vertical_coordinate_file) as f:
+ vertical_coordinate = xr.open_dataset(f).load()
+ # squeeze out the singleton time dimension
+ vertical_coord = vertical_coordinate.squeeze(drop=True)
+
+ sfc_pressure = ds[surface_pressure_name].expand_dims(
+ {z_dim: vertical_coord[z_dim]}, axis=3
+ )
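+ # Interface pressures follow the hybrid sigma-pressure definition
+ # p_i = ak_i + bk_i * surface_pressure; layer thickness is the difference
+ # between consecutive interfaces.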
+ phalf = sfc_pressure * vertical_coord["bk"] + vertical_coord["ak"]
+
+ thickness = (
+ phalf.diff(dim=z_dim)
+ .rename({z_dim: vertical_dim_name})
+ .rename(output_name)
+ .assign_coords(
+ {vertical_dim_name: (vertical_dim_name, ds[vertical_dim_name].values)}
+ )
+ )
+
+ thickness.attrs["units"] = "Pa"
+ thickness.attrs["long_name"] = output_name.replace("_", " ")
+ return ds.assign({output_name: thickness})
+
+
def compute_vertical_coarsening(
ds: xr.Dataset,
vertically_resolved_names: Sequence[str],
@@ -325,8 +474,12 @@ def assert_global_dry_air_mass_conservation(
column_dry_air_mass = (
ds[surface_pressure_name] - ds[total_water_path_name] * GRAVITY
)
- weights = np.cos(np.deg2rad(ds[latitude_dim]))
- global_dry_air_mass = column_dry_air_mass.weighted(weights).mean(dim=dims)
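+ # When latitude is not one of the averaging dims (e.g. for equal-area HEALPix
+ # grids), an unweighted mean is used; otherwise weight by cos(latitude).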
+ if latitude_dim in dims:
+ weights = np.cos(np.deg2rad(ds[latitude_dim]))
+ global_dry_air_mass = column_dry_air_mass.weighted(weights).mean(dim=dims)
+ else:
+ global_dry_air_mass = column_dry_air_mass.mean(dim=dims)
+
global_dry_air_mass_tendency = global_dry_air_mass.diff(time_dim)
print("Mean absolute global dry air pressure tendency [Pa]:")
print(np.abs(global_dry_air_mass_tendency).mean().values)
@@ -385,11 +538,29 @@ def construct_lazy_dataset(
lon_dim=standard_names.longitude_dim,
fraction_modes_kept=config.roundtrip_fraction_kept,
)
+ ds = compute_ocean_fraction(
+ ds,
+ output_name=standard_names.ocean_fraction,
+ land_fraction_name=standard_names.land_fraction,
+ sea_ice_fraction_name=standard_names.sea_ice_fraction,
+ )
+ ds = compute_latent_heat_flux(
+ ds,
+ output_name=standard_names.latent_heat_flux,
+ evaporation_name=standard_names.surface_evaporation_rate,
+ )
ds = compute_specific_total_water(
ds,
water_condensate_names=standard_names.water_species,
output_name=standard_names.specific_total_water,
)
+ ds = compute_pressure_thickness(
+ ds,
+ vertical_coordinate_file=config.reference_vertical_coordinate_file,
+ vertical_dim_name=standard_names.vertical_dim,
+ surface_pressure_name=standard_names.surface_pressure,
+ output_name=standard_names.pressure_thickness,
+ )
ds = compute_vertical_coarsening(
ds,
vertically_resolved_names=standard_names.vertically_resolved,
@@ -424,7 +595,7 @@ def construct_lazy_dataset(
chunks = config.chunking.get_chunks(standard_names)
ds = ds.chunk(chunks)
ds.attrs["history"] = (
- "Dataset computed by ace/scripts/data_process"
+ "Dataset computed by full-model/scripts/data_process"
"/compute_dataset_fv3gfs.py"
f" script, using following input zarrs: {urls.values()}."
)
@@ -473,6 +644,7 @@ def main(
surface_pressure_name=standard_names.surface_pressure,
total_water_path_name=standard_names.total_water_path,
latitude_dim=standard_names.latitude_dim,
+ time_dim=standard_names.time_dim,
)
assert_global_moisture_conservation(
ds,
@@ -482,6 +654,7 @@ def main(
latent_heat_flux_name=standard_names.latent_heat_flux,
latent_heat_of_vaporization=LATENT_HEAT_OF_VAPORIZATION,
precip_rate_name=standard_names.precip_rate,
+ time_dim=standard_names.time_dim,
)
ds = ds.drop(standard_names.dropped_variables)
print(f"Output dataset size is {ds.nbytes / 1e9} GB")
diff --git a/scripts/data_process/compute_dataset.sh b/scripts/data_process/compute_dataset.sh
new file mode 100755
index 0000000..619c629
--- /dev/null
+++ b/scripts/data_process/compute_dataset.sh
@@ -0,0 +1,47 @@
+#!/bin/bash
+
+set -e
+
+COMPUTE_DATASET=true
+
+while [[ "$#" -gt 0 ]]
+do case $1 in
+ --config) CONFIG="$2"
+ shift;;
+ --stats-only) COMPUTE_DATASET=false;;
+ *) echo "Unknown parameter passed: $1"
+ exit 1;;
+esac
+shift
+done
+
+if [[ -z "${CONFIG}" ]]
+then
+ echo "Option --config missing"
+ exit 1;
+fi
+
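+# Parse run names and directories from the config into parallel bash arrays; the
+# Argo workflow indexes into them with a 0..runs_count-1 sequence.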
+names=($(yq -r '.runs | to_entries[].key' ${CONFIG}))
+run_directories=($(yq -r '.runs | to_entries[].value' ${CONFIG}))
+output_directory=$(yq -r '.data_output_directory' ${CONFIG})
+runs_count=$(yq -r '.runs | length' ${CONFIG})
+runs_count_minus_one=$(($runs_count - 1))
+
+# Capture the output of the argo submit command
+output=$(argo submit compute_dataset_argo_workflow.yaml \
+ -p compute_dataset=${COMPUTE_DATASET} \
+ -p python_script="$(< compute_dataset.py)" \
+ -p get_stats_script="$(< get_stats.py)" \
+ -p combine_stats_script="$(< combine_stats.py)" \
+ -p upload_stats_script="$(< upload_stats.py)" \
+ -p config="$(< ${CONFIG})" \
+ -p names="${names[*]}" \
+ -p run_directories="${run_directories[*]}" \
+ -p output_directory="${output_directory}" \
+ -p runs_count_minus_one=${runs_count_minus_one})
+
+# Extract the job name from the output
+job_name=$(echo "$output" | grep 'Name:' | awk '{print $2}')
+
+# Print the job name
+echo "Argo job submitted: $job_name"
diff --git a/scripts/data_process/compute_dataset_argo_workflow.yaml b/scripts/data_process/compute_dataset_argo_workflow.yaml
new file mode 100644
index 0000000..9c05fb0
--- /dev/null
+++ b/scripts/data_process/compute_dataset_argo_workflow.yaml
@@ -0,0 +1,282 @@
+apiVersion: argoproj.io/v1alpha1
+kind: Workflow
+metadata:
+ generateName: compute-fme-dataset-ensemble-
+spec:
+ entrypoint: compute-fme-dataset-ensemble
+ volumes:
+ - name: gcp-key-secret
+ secret:
+ defaultMode: 420
+ secretName: gcp-key
+ arguments:
+ parameters:
+ - name: python_script
+ - name: get_stats_script
+ - name: combine_stats_script
+ - name: upload_stats_script
+ - name: config
+ - name: names
+ - name: run_directories
+ - name: output_directory
+ - name: runs_count_minus_one
+ - name: compute_dataset
+ value: "true" # default value, can be overridden
+ templates:
+ - name: compute-fme-dataset-ensemble
+ steps:
+ - - name: compute-fme-dataset-individual
+ when: "{{workflow.parameters.compute_dataset}} == true"
+ template: compute-fme-dataset-individual
+ arguments:
+ parameters:
+ - name: python_script
+ value: "{{workflow.parameters.python_script}}"
+ - name: stats_script
+ value: "{{workflow.parameters.get_stats_script}}"
+ - name: config
+ value: "{{workflow.parameters.config}}"
+ - name: names
+ value: "{{workflow.parameters.names}}"
+ - name: run_directories
+ value: "{{workflow.parameters.run_directories}}"
+ - name: output_directory
+ value: "{{workflow.parameters.output_directory}}"
+ - name: run
+ value: "{{item}}"
+ withSequence:
+ start: "0"
+ end: "{{workflow.parameters.runs_count_minus_one}}"
+ - - name: get-stats
+ template: get-stats
+ arguments:
+ parameters:
+ - name: python_script
+ value: "{{workflow.parameters.get_stats_script}}"
+ - name: config
+ value: "{{workflow.parameters.config}}"
+ - name: run
+ value: "{{item}}"
+ withSequence:
+ start: "0"
+ end: "{{workflow.parameters.runs_count_minus_one}}"
+ - - name: combine-stats
+ template: combine-stats
+ arguments:
+ parameters:
+ - name: python_script
+ value: "{{workflow.parameters.combine_stats_script}}"
+ - name: config
+ value: "{{workflow.parameters.config}}"
+ - - name: upload-beaker-stats
+ template: upload-beaker-stats
+ arguments:
+ parameters:
+ - name: python_script
+ value: "{{workflow.parameters.upload_stats_script}}"
+ - name: config
+ value: "{{workflow.parameters.config}}"
+ - name: compute-fme-dataset-individual
+ tolerations:
+ - effect: NoSchedule
+ key: dedicated
+ value: med-sim-pool
+ inputs:
+ parameters:
+ - name: python_script
+ - name: stats_script
+ - name: config
+ - name: names
+ - name: run_directories
+ - name: output_directory
+ - name: run
+ container:
+ image: us.gcr.io/vcm-ml/fv3net:3d1589321e40cddc06bb88c22b44f597646473b2
+ resources:
+ limits:
+ cpu: "8000m"
+ memory: "48Gi"
+ requests:
+ cpu: "7500m"
+ memory: "48Gi"
+ command: ["bash", "-c", "-e"]
+ args:
+ - |
+ cat << EOF > script.py
+ {{inputs.parameters.python_script}}
+ EOF
+
+ cat << EOF > get_stats.py
+ {{inputs.parameters.stats_script}}
+ EOF
+
+ cat << EOF > config.yaml
+ {{inputs.parameters.config}}
+ EOF
+
+ run={{inputs.parameters.run}}
+ names=({{inputs.parameters.names}})
+ run_directories=({{inputs.parameters.run_directories}})
+ output_directory={{inputs.parameters.output_directory}}
+
+ name="${names[${run}]}"
+ run_directory="${run_directories[${run}]}"
+
+ output_store=${output_directory}/${name}.zarr
+
+ python script.py --config config.yaml \
+ --run-directory ${run_directory} \
+ --output-store ${output_store}
+ env:
+ - name: GOOGLE_APPLICATION_CREDENTIALS
+ value: /secret/gcp-credentials/key.json
+ - name: CLOUDSDK_AUTH_CREDENTIAL_FILE_OVERRIDE
+ value: /secret/gcp-credentials/key.json
+ volumeMounts:
+ - mountPath: /secret/gcp-credentials
+ name: gcp-key-secret
+ - name: compute-fme-dataset-ensemble-stats
+ steps:
+ - - name: get-stats
+ template: get-stats
+ arguments:
+ parameters:
+ - name: python_script
+ value: "{{workflow.parameters.get_stats_script}}"
+ - name: config
+ value: "{{workflow.parameters.config}}"
+ - name: run
+ value: "{{item}}"
+ withSequence:
+ start: "0"
+ end: "{{workflow.parameters.runs_count_minus_one}}"
+ - - name: combine-stats
+ template: combine-stats
+ arguments:
+ parameters:
+ - name: python_script
+ value: "{{workflow.parameters.combine_stats_script}}"
+ - name: config
+ value: "{{workflow.parameters.config}}"
+ - name: get-stats
+ tolerations:
+ - effect: NoSchedule
+ key: dedicated
+ value: med-sim-pool
+ inputs:
+ parameters:
+ - name: python_script
+ - name: config
+ - name: run
+ container:
+ image: us.gcr.io/vcm-ml/fv3net:3d1589321e40cddc06bb88c22b44f597646473b2
+ resources:
+ limits:
+ cpu: "8000m"
+ memory: "27Gi"
+ requests:
+ cpu: "7500m"
+ memory: "27Gi"
+ command: ["bash", "-c", "-e"]
+ args:
+ - |
+ cat << EOF > script.py
+ {{inputs.parameters.python_script}}
+ EOF
+
+ cat << EOF > config.yaml
+ {{inputs.parameters.config}}
+ EOF
+
+ run={{inputs.parameters.run}}
+
+ python script.py config.yaml ${run}
+ env:
+ - name: GOOGLE_APPLICATION_CREDENTIALS
+ value: /secret/gcp-credentials/key.json
+ - name: CLOUDSDK_AUTH_CREDENTIAL_FILE_OVERRIDE
+ value: /secret/gcp-credentials/key.json
+ volumeMounts:
+ - mountPath: /secret/gcp-credentials
+ name: gcp-key-secret
+ - name: combine-stats
+ tolerations:
+ - effect: NoSchedule
+ key: dedicated
+ value: med-sim-pool
+ inputs:
+ parameters:
+ - name: python_script
+ - name: config
+ container:
+ image: us.gcr.io/vcm-ml/fv3net:3d1589321e40cddc06bb88c22b44f597646473b2
+ resources:
+ limits:
+ cpu: "8000m"
+ memory: "27Gi"
+ requests:
+ cpu: "7500m"
+ memory: "27Gi"
+ command: ["bash", "-c", "-e"]
+ args:
+ - |
+ cat << EOF > script.py
+ {{inputs.parameters.python_script}}
+ EOF
+
+ cat << EOF > config.yaml
+ {{inputs.parameters.config}}
+ EOF
+
+ python script.py config.yaml
+ env:
+ - name: GOOGLE_APPLICATION_CREDENTIALS
+ value: /secret/gcp-credentials/key.json
+ - name: CLOUDSDK_AUTH_CREDENTIAL_FILE_OVERRIDE
+ value: /secret/gcp-credentials/key.json
+ volumeMounts:
+ - mountPath: /secret/gcp-credentials
+ name: gcp-key-secret
+ - name: upload-beaker-stats
+ tolerations:
+ - effect: NoSchedule
+ key: dedicated
+ value: med-sim-pool
+ inputs:
+ parameters:
+ - name: python_script
+ - name: config
+ container:
+ image: us.gcr.io/vcm-ml/beaker-py
+ resources:
+ limits:
+ cpu: "8000m"
+ memory: "27Gi"
+ requests:
+ cpu: "7500m"
+ memory: "27Gi"
+ command: ["bash", "-c", "-e"]
+ args:
+ - |
+ cat << EOF > script.py
+ {{inputs.parameters.python_script}}
+ EOF
+
+ cat << EOF > config.yaml
+ {{inputs.parameters.config}}
+ EOF
+
+ python script.py config.yaml
+ env:
+ - name: GOOGLE_APPLICATION_CREDENTIALS
+ value: /secret/gcp-credentials/key.json
+ - name: CLOUDSDK_AUTH_CREDENTIAL_FILE_OVERRIDE
+ value: /secret/gcp-credentials/key.json
+ - name: BEAKER_TOKEN
+ valueFrom:
+ secretKeyRef:
+ name: beaker-key-andrep
+ key: BEAKER_USER_KEY
+ volumeMounts:
+ - mountPath: /secret/gcp-credentials
+ name: gcp-key-secret
diff --git a/scripts/data_process/compute_dataset_e3smv2.py b/scripts/data_process/compute_dataset_e3smv2.py
index 675db96..91e3fa4 100755
--- a/scripts/data_process/compute_dataset_e3smv2.py
+++ b/scripts/data_process/compute_dataset_e3smv2.py
@@ -18,8 +18,10 @@
import click
import numpy as np
import xarray as xr
-import xpartition # noqa: 401
-from compute_dataset_fv3gfs import (
+import xpartition # noqa: F401
+from compute_dataset import (
+ DatasetComputationConfig,
+ DatasetConfig,
assert_column_integral_of_moisture_is_conserved,
assert_global_dry_air_mass_conservation,
assert_global_moisture_conservation,
@@ -34,142 +36,17 @@
from xtorch_harmonics import roundtrip_filter
# default paths for input/output; can be changed when calling this script
-INPUT_DIR = "/global/cfs/cdirs/e3sm/golaz/E3SM/fme/20230614.v2.LR.F2010/post/atm/180x360_gaussian/ts" # noqa: 501
-TIME_INVARIANT_INPUT_DIR = "/global/cfs/cdirs/m4331/jpduncan/e3smv2/time_invariant"
-OUTPUT_URL = "/pscratch/sd/j/jpduncan/ai2/zarr/2023-11-22-e3smv2-vertically-resolved-1deg-fme-dataset.zarr" # noqa: 501
-
-# these are subdirs of INPUT_DIR
-INSTANT = "6hourly_instant/1yr"
-MEAN = "6hourly/1yr"
+INPUT_DIR = "/global/cfs/cdirs/e3sm/golaz/E3SM/fme/20230614.v2.LR.F2010/post/atm/180x360_gaussian/ts" # noqa: E501
+OUTPUT_URL = "/pscratch/sd/j/jpduncan/ai2/zarr/2023-11-22-e3smv2-vertically-resolved-1deg-fme-dataset.zarr" # noqa: E501
REFERENCE_PRESSURE = 1e5 # Pa
LIQUID_PRECIP_DENSITY = 1e3 # kg/m**3
LATENT_HEAT_OF_VAPORIZATION = 2.501e6 # J/kg
-# 6-hourly instant
-SURFACE_PRESSURE = "PS"
-SURFACE_TEMPERATURE = "TS"
-AIR_TEMPERATURE = "T"
-EASTWARD_WIND = "U"
-NORTHWARD_WIND = "V"
-MEAN_SEA_LEVEL_PRESSURE = "PSL"
-GEOPOTENTIAL = "Z"
-RELATIVE_HUMIDITY = "RH"
-TOTAL_COLUMN_WATER_VAPOR = "TMQ"
-
-# 6-hourly mean
-TOTAL_PRECIP_RATE = "PRECT" # m/s
-LATENT_HEAT_FLUX = "LHFLX"
-
# derived variable names
-SPECIFIC_TOTAL_WATER = "specific_total_water"
-PRECIPITABLE_WATER_PATH = "precipitable_water_path" # total from E3SMv2 model outputs
-TOTAL_WATER_PATH = "total_water_path" # computed by vertical integration of 3D vars
-PRECIP_RATE = "surface_precipitation_rate"
SURFACE_UP_LONGWAVE_FLUX = "surface_upward_longwave_flux"
SURFACE_UP_SHORTWAVE_FLUX = "surface_upward_shortwave_flux"
TOA_UP_SHORTWAVE_FLUX = "top_of_atmos_upward_shortwave_flux"
-PRESSURE_THICKNESS = "pressure_thickness_of_atmospheric_layer"
-
-# dims
-TIME_DIM = "time"
-HORIZONTAL_DIMS = ["lat", "lon"]
-LATITUDE_DIM = "lat"
-VERTICAL_DIM = "lev"
-VERTICAL_INTERFACE_DIM = "ilev"
-HYBRID_LEVEL_COEFFS = ["hyai", "hybi"]
-
-CHUNKS = {"time": 10, "lat": 180, "lon": 360}
-
-# assumed to be found in INSTANT dir
-FOURCASTNET_VANILLA = {
- EASTWARD_WIND: ["1000", "850", "500"],
- NORTHWARD_WIND: ["1000", "850", "500"],
- AIR_TEMPERATURE: ["850", "500"],
- GEOPOTENTIAL: ["1000", "500", "850", "050"],
- RELATIVE_HUMIDITY: ["850", "500"],
- SURFACE_PRESSURE: [""],
- "TREFHT": [""], # temp at 2m
- MEAN_SEA_LEVEL_PRESSURE: [""],
- TOTAL_COLUMN_WATER_VAPOR: [""],
-}
-
-# the variables / filename prefixes we need from the raw E3SMv2 output
-INPUT_VARIABLE_NAMES = {
- INSTANT: [
- SURFACE_PRESSURE,
- SURFACE_TEMPERATURE,
- AIR_TEMPERATURE,
- EASTWARD_WIND,
- NORTHWARD_WIND,
- "Q",
- "CLDLIQ",
- "CLDICE",
- "RAINQM",
- "SNOWQM",
- TOTAL_COLUMN_WATER_VAPOR,
- "TGCLDLWP",
- "TGCLDIWP",
- "OCNFRAC",
- "LANDFRAC",
- "ICEFRAC",
- ],
- MEAN: [
- TOTAL_PRECIP_RATE,
- LATENT_HEAT_FLUX,
- "SHFLX",
- "FLNS",
- "FLDS",
- "FSNS",
- "FSDS",
- "FSNTOA",
- "SOLIN",
- "FLUT",
- # only for water budget dataset:
- "PRECSC",
- "PRECSL",
- "QFLX",
- ],
- "time_invariant": ["PHIS"],
-}
-
-
-WATER_SPECIES_NAMES = [
- "Q",
- "CLDLIQ",
- "CLDICE",
- "RAINQM",
- "SNOWQM",
-]
-
-VARNAMES_3D = [
- AIR_TEMPERATURE,
- EASTWARD_WIND,
- NORTHWARD_WIND,
-] + WATER_SPECIES_NAMES
-
-PRECIPITABLE_WATER_PATH_NAMES = [TOTAL_COLUMN_WATER_VAPOR, "TGCLDLWP", "TGCLDIWP"]
-
-VERTICALLY_RESOLVED_NAMES = [
- SPECIFIC_TOTAL_WATER,
- AIR_TEMPERATURE,
- NORTHWARD_WIND,
- EASTWARD_WIND,
-]
-
-TIME_DERIVATIVE_NAMES = [TOTAL_WATER_PATH]
-
-# computed here: https://github.com/ai2cm/explore/blob/master/jamesd/2023-06-09-e3smv2-vertical-interface-indices.ipynb # noqa: 501
-VERTICAL_LEVEL_INTERFACES = [
- (0, 19),
- (19, 30),
- (30, 38),
- (38, 44),
- (44, 48),
- (48, 53),
- (53, 61),
- (61, 72),
-]
RAD_FLUX_FORMULAS = {
SURFACE_UP_LONGWAVE_FLUX: (lambda x, y: x + y, "FLNS", "FLDS"),
@@ -177,6 +54,9 @@
TOA_UP_SHORTWAVE_FLUX: (lambda x, y: x - y, "SOLIN", "FSNTOA"),
}
+SURFACE_PRECIPITATION = "surface_precipitation_rate"
+PRECIPITABLE_WATER_PATH_NAMES = ["TMQ", "TGCLDLWP", "TGCLDIWP"]
+
DROP_VARIABLE_NAMES = {
"2D": [ # variables to drop when opening 2D vars
"time_bnds",
@@ -197,46 +77,8 @@
"hyam",
"hybm",
],
- "POST": [ # variables to drop at the end
- AIR_TEMPERATURE,
- EASTWARD_WIND,
- NORTHWARD_WIND,
- SPECIFIC_TOTAL_WATER,
- PRESSURE_THICKNESS,
- TOTAL_PRECIP_RATE,
- PRECIPITABLE_WATER_PATH,
- TOTAL_COLUMN_WATER_VAPOR,
- VERTICAL_DIM,
- VERTICAL_INTERFACE_DIM,
- HYBRID_LEVEL_COEFFS[0],
- HYBRID_LEVEL_COEFFS[1],
- "Q",
- "CLDLIQ",
- "CLDICE",
- "RAINQM",
- "SNOWQM",
- "TGCLDLWP",
- "TGCLDIWP",
- "FLNS",
- "FSNS",
- "FSNTOA",
- "PRECSC",
- "PRECSL",
- "QFLX",
- ],
}
-# dataset of 2D vars for checking water conservation
-WATER_BUDGET_DATASET_VARS = PRECIPITABLE_WATER_PATH_NAMES + [
- SURFACE_PRESSURE,
- PRECIPITABLE_WATER_PATH,
- TOTAL_PRECIP_RATE,
- LATENT_HEAT_FLUX,
- "PRECSC",
- "PRECSL",
- "QFLX",
-]
-
def expand_names_by_level(variables: MutableMapping[str, List[str]]) -> List[str]:
names = []
@@ -246,45 +88,50 @@ def expand_names_by_level(variables: MutableMapping[str, List[str]]) -> List[str
def get_nc_paths(
- base_dir: str, var_names: Optional[List[str]] = None
+ base_dir: str, var_names: Sequence[str]
) -> MutableMapping[str, List[str]]:
- if var_names is None:
- paths = {"time_invariant": list(glob(os.path.join(base_dir, f"*.nc")))}
- else:
- paths = {
- var_name: sorted(list(glob(os.path.join(base_dir, f"{var_name}_*.nc"))))
- for var_name in var_names
- }
+ paths = {
+ var_name: sorted(list(glob(os.path.join(base_dir, f"{var_name}_*.nc"))))
+ for var_name in var_names
+ }
+ return paths
+
+
+def get_time_invariant_nc_paths(
+ base_dir: Optional[str],
+) -> MutableMapping[str, List[str]]:
+ paths = {
+ "time_invariant": list(glob(os.path.join(base_dir, f"*.nc"))) # type: ignore
+ }
return paths
def open_dataset(
dataset_dirs: MutableMapping[str, str],
- time_invariant_dir: str,
- input_variable_names: MutableMapping[str, List[str]] = INPUT_VARIABLE_NAMES,
- varnames_3d: List[str] = VARNAMES_3D,
- drop_variable_names: MutableMapping[str, List[str]] = DROP_VARIABLE_NAMES,
- chunks: MutableMapping[str, int] = CHUNKS,
- vanilla: bool = False,
+ config: DatasetComputationConfig,
) -> xr.Dataset:
"""Open datasets from NetCDF files in directory that match input variable names."""
- if vanilla:
- var_names = expand_names_by_level(FOURCASTNET_VANILLA)
- var_paths = get_nc_paths(dataset_dirs[INSTANT], var_names)
- else:
- var_paths = {}
- for key in dataset_dirs.keys():
- var_paths.update(get_nc_paths(dataset_dirs[key], input_variable_names[key]))
- var_paths.update(get_nc_paths(time_invariant_dir))
+ var_paths: MutableMapping[str, List[str]] = {}
+ input_variable_names = config.variable_sources
+ for key in dataset_dirs.keys():
+ var_paths.update(get_nc_paths(dataset_dirs[key], input_variable_names[key]))
+ var_paths.update(get_time_invariant_nc_paths(config.time_invariant_dir))
print(
f"Opening {len(list(chain.from_iterable(var_paths.values())))} files with "
f"{len(var_paths.keys())} vars..."
)
+ standard_names = config.standard_names
+ varnames_3D = [
+ standard_names.air_temperature,
+ standard_names.eastward_wind,
+ standard_names.northward_wind,
+ ] + standard_names.water_species
+ chunks = config.chunking.get_chunks(config.standard_names)
datasets = {}
start = time.time()
if "time_invariant" in var_paths:
for path in var_paths["time_invariant"]:
- ds = xr.open_dataset(path).drop(drop_variable_names["2D"], errors="ignore")
+ ds = xr.open_dataset(path).drop(DROP_VARIABLE_NAMES["2D"], errors="ignore")
if "time" in ds.coords:
ds = ds.isel(time=0, drop=True)
for varname in input_variable_names["time_invariant"]:
@@ -293,10 +140,10 @@ def open_dataset(
del var_paths["time_invariant"]
for varname, paths in var_paths.items():
var_start = time.time()
- if varname in varnames_3d:
- drop_vars = drop_variable_names["3D"]
+ if varname in varnames_3D:
+ drop_vars = DROP_VARIABLE_NAMES["3D"]
else:
- drop_vars = drop_variable_names["2D"]
+ drop_vars = DROP_VARIABLE_NAMES["2D"]
datasets[varname] = xr.open_mfdataset(
paths,
chunks=chunks,
@@ -311,12 +158,12 @@ def open_dataset(
def compute_pressure_thickness(
ds: xr.Dataset,
- vertical_dim: str = VERTICAL_DIM,
- vertical_interface_dim: str = VERTICAL_INTERFACE_DIM,
- hybrid_level_coeffs: List[str] = HYBRID_LEVEL_COEFFS,
- reference_pressure: float = REFERENCE_PRESSURE,
- surface_pressure: str = SURFACE_PRESSURE,
- output_name: str = PRESSURE_THICKNESS,
+ vertical_dim: str,
+ vertical_interface_dim: str,
+ hybrid_level_coeffs: List[str],
+ reference_pressure: float,
+ surface_pressure: str,
+ output_name: str,
):
hyai, hybi = hybrid_level_coeffs
sfc_pressure = ds[surface_pressure].expand_dims(
@@ -336,9 +183,9 @@ def compute_pressure_thickness(
def compute_coarse_ak_bk(
ds: xr.Dataset,
- interface_indices: Sequence[Tuple[int, int]] = VERTICAL_LEVEL_INTERFACES,
- z_dim="ilev",
- hybrid_level_coeffs: List[str] = HYBRID_LEVEL_COEFFS,
+ interface_indices: Sequence[Tuple[int, int]],
+ z_dim: str,
+ hybrid_level_coeffs: List[str],
reference_pressure: float = REFERENCE_PRESSURE,
):
"""Return dataset with scalar ak and bk coordinates that define coarse
@@ -376,9 +223,9 @@ def compute_coarse_ak_bk(
def compute_surface_precipitation_rate(
ds: xr.Dataset,
- total_precip_rate_name: str = TOTAL_PRECIP_RATE,
+ total_precip_rate_name: str,
liquid_precip_density: float = LIQUID_PRECIP_DENSITY,
- output_name: str = PRECIP_RATE,
+ output_name: str = SURFACE_PRECIPITATION,
):
precip_mass_flux = ds[total_precip_rate_name] * liquid_precip_density
precip_mass_flux.attrs["units"] = "kg/m2/s"
@@ -388,8 +235,8 @@ def compute_surface_precipitation_rate(
def compute_precipitable_water_path(
ds: xr.Dataset,
- precipitable_water_path_names: List[str] = PRECIPITABLE_WATER_PATH_NAMES,
- output_name: str = PRECIPITABLE_WATER_PATH,
+ output_name: str,
+ precipitable_water_path_names: List[str],
):
water_path = sum([ds[name] for name in precipitable_water_path_names])
water_path.attrs["units"] = "kg/m2"
@@ -412,55 +259,83 @@ def compute_rad_fluxes(
def construct_lazy_dataset(
+ config: DatasetComputationConfig,
dataset_dirs: MutableMapping[str, str],
- time_invariant_dir: str,
- vanilla: bool = False,
- sht_roundtrip: bool = False,
) -> xr.Dataset:
start = time.time()
+ standard_names = config.standard_names
print(f"Opening dataset...")
- ds = open_dataset(dataset_dirs, time_invariant_dir, vanilla=vanilla)
+ ds = open_dataset(dataset_dirs, config)
print(f"Dataset opened in {time.time() - start:.2f} s total.")
print(f"Input dataset size is {ds.nbytes / 1e9} GB")
- if sht_roundtrip:
- ds = roundtrip_filter(ds, lat_dim="lat", lon_dim="lon")
- if not vanilla:
- ds = compute_pressure_thickness(ds)
- ds = compute_rad_fluxes(ds)
- ds = compute_surface_precipitation_rate(ds)
- ds = compute_precipitable_water_path(ds) # only used for conservation check
- ds = compute_specific_total_water(ds, WATER_SPECIES_NAMES, SPECIFIC_TOTAL_WATER)
- ds = compute_vertical_coarsening(
- ds,
- VERTICALLY_RESOLVED_NAMES,
- VERTICAL_LEVEL_INTERFACES,
- VERTICAL_DIM,
- PRESSURE_THICKNESS,
- )
- ds = compute_column_moisture_integral(
+ if config.roundtrip_fraction_kept is not None:
+ ds = roundtrip_filter(
ds,
- SPECIFIC_TOTAL_WATER,
- TOTAL_WATER_PATH,
- PRESSURE_THICKNESS,
- VERTICAL_DIM,
+ lat_dim=standard_names.latitude_dim,
+ lon_dim=standard_names.longitude_dim,
+ fraction_modes_kept=config.roundtrip_fraction_kept,
)
- ds[TOTAL_WATER_PATH].attrs["units"] = "kg/m2" # change to E3SMv2 format
- ds = compute_tendencies(ds, TIME_DERIVATIVE_NAMES, TIME_DIM)
- ds = compute_column_advective_moisture_tendency(
- ds,
- f"tendency_of_{TOTAL_WATER_PATH}",
- LATENT_HEAT_FLUX,
- PRECIP_RATE,
- LATENT_HEAT_OF_VAPORIZATION,
- )
- ak_bk_ds = compute_coarse_ak_bk(ds, VERTICAL_LEVEL_INTERFACES)
- ds = xr.merge([ds, ak_bk_ds])
- ds_dirs = list(dataset_dirs.values())
- else:
- ds_dirs = [dataset_dirs[INSTANT]]
- ds = ds.chunk(CHUNKS).astype(np.float32)
+
+ ds = compute_pressure_thickness(
+ ds,
+ vertical_dim=standard_names.vertical_dim,
+ vertical_interface_dim=standard_names.vertical_interface_dim,
+ hybrid_level_coeffs=standard_names.hybrid_level_coeffs,
+ reference_pressure=REFERENCE_PRESSURE,
+ surface_pressure=standard_names.surface_pressure,
+ output_name=standard_names.pressure_thickness,
+ )
+ ds = compute_rad_fluxes(ds)
+ ds = compute_surface_precipitation_rate(
+ ds,
+ total_precip_rate_name=standard_names.precip_rate,
+ )
+ water_species_name = [
+ item for item in standard_names.water_species if item.lower() != "none"
+ ]
+ ds = compute_specific_total_water(
+ ds,
+ water_condensate_names=water_species_name,
+ output_name=standard_names.specific_total_water,
+ )
+ ds = compute_vertical_coarsening(
+ ds,
+ vertically_resolved_names=standard_names.vertically_resolved,
+ interface_indices=config.vertical_coarsening_indices,
+ dim=standard_names.vertical_dim,
+ pressure_thickness_name=standard_names.pressure_thickness,
+ )
+ ds = compute_column_moisture_integral(
+ ds,
+ input_name=standard_names.specific_total_water,
+ output_name=standard_names.total_water_path,
+ pressure_thickness_name=standard_names.pressure_thickness,
+ dim=standard_names.vertical_dim,
+ )
+ ds = compute_tendencies(
+ ds,
+ time_derivative_names=standard_names.time_derivative_names,
+ dim=standard_names.time_dim,
+ )
+ ds = compute_column_advective_moisture_tendency(
+ ds,
+ pwat_tendency=standard_names.pwat_tendency,
+ latent_heat_flux=standard_names.latent_heat_flux,
+ precip=standard_names.precip_rate,
+ latent_heat_of_vaporization=LATENT_HEAT_OF_VAPORIZATION,
+ )
+ ak_bk_ds = compute_coarse_ak_bk(
+ ds,
+ interface_indices=config.vertical_coarsening_indices,
+ z_dim=standard_names.vertical_interface_dim,
+ hybrid_level_coeffs=standard_names.hybrid_level_coeffs,
+ )
+ ds = xr.merge([ds, ak_bk_ds])
+ ds_dirs = list(dataset_dirs.values())
+ chunks = config.chunking.get_chunks(config.standard_names)
+ ds = ds.chunk(chunks).astype(np.float32)
ds.attrs["history"] = (
- "Dataset computed by ace/scripts/e3smv2_data_process"
+ "Dataset computed by full-model/scripts/e3smv2_data_process"
"/compute_dataset_e3smv2.py"
f" script, using inputs from the following directories: {ds_dirs}."
)
@@ -474,94 +349,121 @@ def construct_lazy_dataset(
@click.command()
-@click.option("--debug", is_flag=True, help="Print metadata instead of writing output.")
-@click.option("--subsample", is_flag=True, help="Subsample the data before writing.")
-@click.option("--vanilla", is_flag=True, help="Compute vanilla FourCastNet dataset.")
-@click.option("--check-conservation", is_flag=True, help="Check conservation.")
-@click.option(
- "--water-budget-dataset",
- is_flag=True,
- help="Create a dataset of 2D vars for checking the water budget.",
-)
-@click.option("--sht-roundtrip", is_flag=True, help="SHT roundtrip as a first step.")
+@click.option("--config", help="Path to dataset configuration YAML file.")
@click.option(
"-i",
"--input-dir",
default=INPUT_DIR,
help="Directory in which to find time-varying input ncs.",
)
+@click.option("-o", "--output", default=OUTPUT_URL, help="URL to write output to.")
+@click.option("--debug", is_flag=True, help="Print metadata instead of writing output.")
+@click.option("--subsample", is_flag=True, help="Subsample the data before writing.")
+@click.option("--check-conservation", is_flag=True, help="Check conservation.")
@click.option(
- "-t",
- "--time-invariant-input-dir",
- default=TIME_INVARIANT_INPUT_DIR,
- help="Directory in which to find time-invariant input ncs.",
+ "--water-budget-dataset",
+ is_flag=True,
+ help="Create a dataset of 2D vars for checking the water budget.",
)
-@click.option("-o", "--output", default=OUTPUT_URL, help="URL to write output to.")
-@click.option("--n-split", default=100, help="Number of steps to split job over.")
@click.option("--n-workers", default=4, help="Number of Dask workers.")
def main(
+ config,
+ input_dir,
+ output,
debug,
subsample,
- vanilla,
- sht_roundtrip,
check_conservation,
water_budget_dataset,
- input_dir,
- time_invariant_input_dir,
- output,
- n_split,
n_workers,
):
xr.set_options(keep_attrs=True)
_ = Client(n_workers=n_workers)
-
- dataset_dirs = {
- INSTANT: os.path.join(input_dir, INSTANT),
- MEAN: os.path.join(input_dir, MEAN),
- }
- ds = construct_lazy_dataset(
- dataset_dirs, time_invariant_input_dir, vanilla, sht_roundtrip
- )
+ config = DatasetConfig.from_file(config).dataset_computation
+ standard_names = config.standard_names
+ dataset_dirs = {}
+ for key in config.variable_sources.keys():
+ if key != "time_invariant":
+ dataset_dirs[key] = os.path.join(input_dir, key)
+ ds = construct_lazy_dataset(config, dataset_dirs)
if subsample:
ds = ds.isel(time=slice(10, 13))
if check_conservation:
# these currently fail
+ ds = compute_precipitable_water_path(
+ ds,
+ output_name=standard_names.precipitable_water_path,
+ precipitable_water_path_names=PRECIPITABLE_WATER_PATH_NAMES,
+ )
assert_column_integral_of_moisture_is_conserved(
- ds, PRECIPITABLE_WATER_PATH, TOTAL_WATER_PATH
+ ds, standard_names.precipitable_water_path, standard_names.total_water_path
)
assert_global_dry_air_mass_conservation(
ds,
- dims=HORIZONTAL_DIMS,
- surface_pressure_name=SURFACE_PRESSURE,
- total_water_path_name=TOTAL_WATER_PATH,
- latitude_dim=LATITUDE_DIM,
- time_dim=TIME_DIM,
+ dims=standard_names.horizontal_dims,
+ surface_pressure_name=standard_names.surface_pressure,
+ total_water_path_name=standard_names.total_water_path,
+ latitude_dim=standard_names.latitude_dim,
+ time_dim=standard_names.time_dim,
)
assert_global_moisture_conservation(
ds,
- dims=HORIZONTAL_DIMS,
- latitude_dim=LATITUDE_DIM,
- total_water_path_name=TOTAL_WATER_PATH,
- latent_heat_flux_name=LATENT_HEAT_FLUX,
+ dims=standard_names.horizontal_dims,
+ latitude_dim=standard_names.latitude_dim,
+ total_water_path_name=standard_names.total_water_path,
+ latent_heat_flux_name=standard_names.latent_heat_flux,
latent_heat_of_vaporization=LATENT_HEAT_OF_VAPORIZATION,
- precip_rate_name=PRECIP_RATE,
- time_dim=TIME_DIM,
+ precip_rate_name=standard_names.precip_rate,
+ time_dim=standard_names.time_dim,
)
if water_budget_dataset:
- ds = ds[WATER_BUDGET_DATASET_VARS]
+ water_budget_dataset_vars = [
+ standard_names.surface_pressure,
+ standard_names.precipitable_water_path,
+ standard_names.precip_rate,
+ standard_names.latent_heat_flux,
+ "PRECSC",
+ "PRECSL",
+ "QFLX",
+ "TMQ",
+ "TGCLDLWP",
+ "TGCLDIWP",
+ ]
+
+ ds = ds[water_budget_dataset_vars]
else:
- ds = ds.drop(DROP_VARIABLE_NAMES["POST"], errors="ignore")
+ dropped_variables = (
+ [
+ item
+ for item in standard_names.dropped_variables
+ if item.lower() != "none"
+ ]
+ + standard_names.hybrid_level_coeffs
+ + [
+ standard_names.precip_rate,
+ standard_names.vertical_interface_dim,
+ "TMQ",
+ "TGCLDLWP",
+ "TGCLDIWP",
+ "FLNS",
+ "FSNS",
+ "FSNTOA",
+ "PRECSC",
+ "PRECSL",
+ "QFLX",
+ ]
+ )
+ ds = ds.drop(dropped_variables, errors="ignore")
print(f"Output dataset size is {ds.nbytes / 1e9} GB")
if debug:
with xr.set_options(display_max_rows=500):
print(ds)
else:
ds.partition.initialize_store(output)
- for i in range(n_split):
- print(f"Writing segment {i + 1} / {n_split}")
+ for i in range(config.n_split):
+ print(f"Writing segment {i + 1} / {config.n_split}")
with ProgressBar():
ds.partition.write(
- output, n_split, ["time"], i, collect_variable_writes=True
+ output, config.n_split, ["time"], i, collect_variable_writes=True
)
diff --git a/scripts/data_process/compute_hpx_dataset.py b/scripts/data_process/compute_hpx_dataset.py
new file mode 100644
index 0000000..5556f5f
--- /dev/null
+++ b/scripts/data_process/compute_hpx_dataset.py
@@ -0,0 +1,228 @@
+# This script is used to compute a training dataset from the "raw"
+# FV3GFS data stored in zarr form on GCS.
+
+# The dependencies of this script are installed in the "fv3net" conda environment
+# which can be installed using fv3net's Makefile. See
+# https://github.com/ai2cm/fv3net/blob/8ed295cf0b8ca49e24ae5d6dd00f57e8b30169ac/Makefile#L310
+
+# The resulting dataset is about 194GB (the input is about 2.5TB). Running this script
+# on my 8-CPU VM takes about 2.5 hours. See "compute_dataset_fv3gfs_argo_workflow.yaml"
+# for a workflow which parallelizes this script across the 11-member ensemble and runs
+# it on our GKE cluster.
+
+import os
+import pdb
+import sys
+
+sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+import multiprocessing as mp
+from functools import partial
+
+import click
+import earth2grid
+import numpy as np
+import torch
+import xarray as xr
+import xpartition # noqa
+from compute_dataset import (
+ LATENT_HEAT_OF_VAPORIZATION,
+ DatasetComputationConfig,
+ DatasetConfig,
+ DLWPChunkingConfig,
+ DLWPNameMapping,
+ assert_column_integral_of_moisture_is_conserved,
+ assert_global_dry_air_mass_conservation,
+ assert_global_moisture_conservation,
+ get_dataset_urls,
+ open_datasets,
+)
+
+
+def regrid_tensor(x, regrid_func, shape):
+ data = regrid_func(torch.tensor(x, dtype=torch.double))
+ return data.numpy().reshape(shape)
+
+
+def _pool_func(ds, store, n_partitions, partition_dims, i):
+ ds.partition.write(store, n_partitions, partition_dims, i)
+ print(f"Finished writing partition {i}")
+
+
+def hpx_regrid(
+ ds: xr.Dataset,
+ dlwp_names: DLWPNameMapping,
+ level: int = 6, # regrid resolution
+ n_side: int = 64,
+) -> xr.Dataset:
+ lat_long_names = dlwp_names.lat_lon_dims
+ longitude = lat_long_names[1]
+ latitude = lat_long_names[0]
+ lons = ds[longitude]
+ lats = ds[latitude]
+
+ hpx = earth2grid.healpix.Grid(
+ level=level, pixel_order=earth2grid.healpix.HEALPIX_PAD_XY
+ )
+ src = earth2grid.latlon.LatLonGrid(lat=list(lats), lon=list(lons))
+ # Regridder
+ regrid = earth2grid.get_regridder(src, hpx)
+
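+ # Apply the regridder slice-by-slice: (lat, lon) are the input core dims and the
+ # HEALPix (face, height, width) dims are the output core dims; remaining dims are
+ # vectorized over lazily with dask.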
+ ds_regridded = xr.apply_ufunc(
+ regrid_tensor,
+ ds,
+ input_core_dims=[[latitude, longitude]],
+ output_core_dims=[["face", "height", "width"]],
+ output_sizes={"face": 12, "height": n_side, "width": n_side},
+ output_dtypes=[float],
+ dask="parallelized",
+ vectorize=True,
+ on_missing_core_dim="copy",
+ kwargs={"regrid_func": regrid, "shape": (12, n_side, n_side)},
+ dask_gufunc_kwargs={"allow_rechunk": True},
+ )
+ # Assign coordinates to the regridded dataset
+ time_coords = ds.coords["time"]
+ nside_coords = np.arange(n_side)
+ grid_coords = np.arange(12)
+ ds_regridded = ds_regridded.assign_coords(
+ time=time_coords,
+ face=grid_coords,
+ height=nside_coords,
+ width=nside_coords,
+ )
+
+ return ds_regridded
+
+
+def construct_hpx_dataset(
+ config: DatasetComputationConfig,
+ run_directory: str,
+ output_directory: str,
+ toy_dataset: bool = False,
+) -> xr.Dataset:
+ dlwp_names = config.standard_names
+ if not isinstance(dlwp_names, DLWPNameMapping):
+ raise TypeError("Expected to be passed type of DLWPNameMapping.")
+ dlwp_chunking = config.chunking
+ if not isinstance(dlwp_chunking, DLWPChunkingConfig):
+ raise TypeError("Expected to be passed type of DLWPChunkingConfig.")
+
+ urls = get_dataset_urls(config, run_directory)
+ print(urls)
+ ds = open_datasets(config, urls)
+ for var in ds:
+ del ds[var].encoding["chunks"]
+ del ds[var].encoding["preferred_chunks"]
+ print(f"Input dataset size is {ds.nbytes / 1e9} GB")
+ if toy_dataset:
+ ds = ds.isel(time=slice(0, 200))
+ # We would like to:
+ # 1. map to healpix mesh
+ ds = hpx_regrid(
+ ds=ds,
+ dlwp_names=dlwp_names,
+ n_side=64,
+ )
+ print(f"After regrid: {ds}")
+
+ # 2. chunk and save
+ chunks = config.chunking.get_chunks(dlwp_names)
+ ds = ds.chunk(chunks)
+ ds.attrs["history"] = (
+ "Dataset computed by full-model/scripts/data_process"
+ "/compute_hpx_dataset.py"
+ f" script, using following input zarrs: {urls.values()}."
+ )
+ ds.attrs["vertical_coordinate"] = (
+ "The pressure at level interfaces can by computed as "
+ "p_i = ak_i + bk_i * PRESsfc, where PRESsfc is the surface pressure and the "
+ "p_i pressure corresponds to the interface at the top of the i'th finite "
+ "volume layer, counting down from the top of atmosphere."
+ )
+ ds = ds.rename(config.renaming)
+ return ds
+
+
+@click.command()
+@click.option("--config", help="Path to dataset configuration YAML file.")
+@click.option("--run-directory", help="Path to reference run directory.")
+@click.option("--output-store", help="Path to output zarr store.")
+@click.option("--debug", is_flag=True, help="Print metadata instead of writing output.")
+@click.option("--subsample", is_flag=True, help="Subsample the data before writing.")
+@click.option("--check-conservation", is_flag=True, help="Check conservation.")
+@click.option("--num-processes", default=16, help="Number of processes to spin up.")
+def main(
+ config,
+ run_directory,
+ output_store,
+ debug,
+ subsample,
+ check_conservation,
+ num_processes,
+):
+ config = DatasetConfig.from_file(config).dataset_computation
+ dlwp_names = config.standard_names
+ print(f"--run-directory is {run_directory}")
+ print(f"--output-store is {output_store}")
+ ds = construct_hpx_dataset(
+ config=config,
+ run_directory=run_directory,
+ output_directory=output_store,
+ toy_dataset=False,
+ )
+ if subsample:
+ ds = ds.isel(time=slice(10, 13))
+ if check_conservation:
+ assert_column_integral_of_moisture_is_conserved(
+ ds,
+ precipitable_water_path_name=dlwp_names.precipitable_water_path,
+ total_water_path_name=dlwp_names.total_water_path,
+ )
+ assert_global_dry_air_mass_conservation(
+ ds,
+ dims=dlwp_names.horizontal_dims,
+ surface_pressure_name=dlwp_names.surface_pressure,
+ total_water_path_name=dlwp_names.total_water_path,
+ latitude_dim=dlwp_names.latitude_dim,
+ time_dim=dlwp_names.time_dim,
+ )
+ assert_global_moisture_conservation(
+ ds,
+ dims=dlwp_names.horizontal_dims,
+ latitude_dim=dlwp_names.latitude_dim,
+ total_water_path_name=dlwp_names.total_water_path,
+ latent_heat_flux_name=dlwp_names.latent_heat_flux,
+ latent_heat_of_vaporization=LATENT_HEAT_OF_VAPORIZATION,
+ precip_rate_name=dlwp_names.precip_rate,
+ time_dim=dlwp_names.time_dim,
+ )
+ drop_vars = [var for var in dlwp_names.dropped_variables if var in ds]
+ ds = ds.drop(drop_vars)
+ print(f"Output dataset size is {ds.nbytes / 1e9} GB")
+
+ if debug:
+ with xr.set_options(display_max_rows=500):
+ print(ds)
+ else:
+ n_partitions = config.n_split
+ partition_dims = [dlwp_names.time_dim]
+ store = f"{output_store}.zarr"
+ ds.partition.initialize_store(store)
+
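+ # Write the time partitions in parallel worker processes; the "forkserver" start
+ # method helps avoid fork-safety issues with threaded libraries used by xarray/dask.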
+ with mp.get_context("forkserver").Pool(num_processes) as pool:
+ pool.map(
+ partial(_pool_func, ds, store, n_partitions, partition_dims),
+ range(n_partitions),
+ )
+
+
+if __name__ == "__main__":
+ try:
+ main()
+ except Exception as e:
+ print(f"An error occurred: {e}", file=sys.stderr)
+ pdb.post_mortem() # Start the debugger
+ raise # Re-raise the exception to preserve the traceback
diff --git a/scripts/data_process/compute_hpx_dataset.sh b/scripts/data_process/compute_hpx_dataset.sh
new file mode 100755
index 0000000..92f874a
--- /dev/null
+++ b/scripts/data_process/compute_hpx_dataset.sh
@@ -0,0 +1,52 @@
+#!/bin/bash
+
+# Function to check if directory is in Python path and add it if not
+function add_to_pythonpath() {
+ local dir_to_add="$1"
+ if [[ ":$PYTHONPATH:" != *":$dir_to_add:"* ]]; then
+ export PYTHONPATH="$dir_to_add:$PYTHONPATH"
+ fi
+}
+
+# Add your own full-model directory to Python path if not already included
+add_to_pythonpath "~/full-model"
+
+ARGO=false
+
+while [[ "$#" -gt 0 ]]
+do
+ case $1 in
+ --config) CONFIG="$2"
+ shift;;
+ --argo) ARGO="$2"
+ shift;;
+ *) echo "Unknown parameter passed: $1"
+ exit 1;;
+ esac
+ shift
+done
+
+if [[ -z "${CONFIG}" ]]
+then
+ echo "Option --config missing"
+ exit 1
+fi
+
+run_directory=$(yq -r '.runs.run_directory' ${CONFIG})
+output_directory=$(yq -r '.data_output_directory' ${CONFIG})
+
+if [[ "$ARGO" == "true" ]]
+then
+ output=$(argo submit full-model/scripts/data_process/compute_hpx_dataset_argo_workflow.yaml \
+ -p python_script="$(< full-model/scripts/data_process/compute_hpx_dataset.py)" \
+ -p config="$(< ${CONFIG})" \
+ -p run_directory="${run_directory}" \
+ -p output_directory="${output_directory}")
+
+ job_name=$(echo "$output" | grep 'Name:' | awk '{print $2}')
+ echo "Argo job submitted: $job_name"
+else
+ python3 full-model/scripts/data_process/compute_hpx_dataset.py --config="${CONFIG}" \
+ --run-directory="${run_directory}" \
+ --output-store="${output_directory}"
+fi
\ No newline at end of file
diff --git a/scripts/data_process/compute_repeating_forcing.py b/scripts/data_process/compute_repeating_forcing.py
index fdef451..4ed7659 100755
--- a/scripts/data_process/compute_repeating_forcing.py
+++ b/scripts/data_process/compute_repeating_forcing.py
@@ -97,7 +97,7 @@ def main(n_times: int, input_dir: Path, output_dir: Path, repeat_variables, nc_f
time_coord = xr.cftime_range(
ds.time.item(0),
periods=len(ds.time) * n_times,
- freq=f"{dt}H",
+ freq=f"{dt}h",
calendar=ds.time.dt.calendar,
)
diff --git a/scripts/data_process/compute_stats.sh b/scripts/data_process/compute_stats.sh
new file mode 100755
index 0000000..ceb2c2f
--- /dev/null
+++ b/scripts/data_process/compute_stats.sh
@@ -0,0 +1,3 @@
+#!/bin/bash
+
+./compute_dataset.sh --stats-only "$@"
diff --git a/scripts/data_process/configs/e3sm-1deg-8layer.yaml b/scripts/data_process/configs/e3sm-1deg-8layer.yaml
new file mode 100644
index 0000000..1432f89
--- /dev/null
+++ b/scripts/data_process/configs/e3sm-1deg-8layer.yaml
@@ -0,0 +1,85 @@
+runs:
+ 2024-07-10-e3smv2-1deg-testing: ""
+data_output_directory: /global/cfs/cdirs/m4492/fme-preprocess/zarr/
+stats:
+ output_directory: /global/cfs/cdirs/m4492/fme-preprocess/2024-07-10-e3smv2-1deg-testing
+ start_date: "1970-01-01"
+ end_date: "1970-12-31"
+ data_type: E3SMV2
+ beaker_dataset: e3sm-1deg-8layers-stats-1970 # this is not used in e3sm data processing
+dataset_computation:
+ chunking:
+ time_dim: 10
+ latitude_dim: 180
+ longitude_dim: 360
+ reference_vertical_coordinate_file: None
+ time_invariant_dir: /global/cfs/cdirs/m4331/jpduncan/e3smv2/time_invariant
+ vertical_coarsening_indices:
+# computed here: https://github.com/ai2cm/explore/blob/master/jamesd/2023-06-09-e3smv2-vertical-interface-indices.ipynb
+ - [0, 19]
+ - [19, 30]
+ - [30, 38]
+ - [38, 44]
+ - [44, 48]
+ - [48, 53]
+ - [53, 61]
+ - [61, 72]
+ roundtrip_fraction_kept: 1.0
+ n_split: 100
+ variable_sources:
+ time_invariant:
+ - PHIS
+ 6hourly_instant/1yr:
+ - PS
+ - TS
+ - T
+ - U
+ - V
+ - Q
+ - CLDLIQ
+ - CLDICE
+ - RAINQM
+ - SNOWQM
+ - TMQ
+ - TGCLDLWP
+ - TGCLDIWP
+ - OCNFRAC
+ - LANDFRAC
+ - ICEFRAC
+ 6hourly/1yr:
+ - PRECT
+ - LHFLX
+ - SHFLX
+ - FLNS
+ - FLDS
+ - FSNS
+ - FSDS
+ - FSNTOA
+ - SOLIN
+ - FLUT
+ - PRECSC
+ - PRECSL
+ - QFLX
+ standard_names:
+ longitude_dim: lon
+ latitude_dim: lat
+ vertical_dim: lev
+ vertical_interface_dim: ilev
+ time_dim: time
+ surface_pressure: PS
+ latent_heat_flux: LHFLX
+ precip_rate: PRECT
+ precipitable_water_path: precipitable_water_path
+ pressure_thickness: pressure_thickness_of_atmospheric_layer
+ air_temperature: T
+ specific_humidity: Q
+ cloud_water_mixing_ratio: CLDLIQ
+ cloud_ice_mixing_ratio: CLDICE
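+ # "None" marks a water species that is not available in the E3SMv2 output;
+ # such entries are filtered out of StandardNameMapping.water_species.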
+ graupel_mixing_ratio: None
+ rain_mixing_ratio: RAINQM
+ snow_mixing_ratio: SNOWQM
+ northward_wind: V
+ eastward_wind: U
+ hybrid_level_coeffs:
+ - hyai
+ - hybi
\ No newline at end of file
diff --git a/scripts/data_process/configs/era5-1deg-16layer-1940-2022.yaml b/scripts/data_process/configs/era5-1deg-16layer-1940-2022.yaml
new file mode 100644
index 0000000..b01df38
--- /dev/null
+++ b/scripts/data_process/configs/era5-1deg-16layer-1940-2022.yaml
@@ -0,0 +1,9 @@
+runs:
+ 2024-07-11-era5-1deg-16layer-1940-2022: "" # no real data source, config only for computing stats
+data_output_directory: gs://vcm-ml-intermediate
+stats:
+ output_directory: gs://vcm-ml-intermediate/era5-1deg-16layer-stats-1990-2019
+ beaker_dataset: era5-1deg-16layer-stats-1990-2019
+ start_date: "1990-01-01"
+ end_date: "2019-12-31"
+ data_type: ERA5
diff --git a/scripts/data_process/configs/era5-1deg-8layer-1940-2022.yaml b/scripts/data_process/configs/era5-1deg-8layer-1940-2022.yaml
new file mode 100644
index 0000000..b200418
--- /dev/null
+++ b/scripts/data_process/configs/era5-1deg-8layer-1940-2022.yaml
@@ -0,0 +1,9 @@
+runs:
+ 2024-06-20-era5-1deg-8layer-1940-2022: "" # no real data source, config only for computing stats
+data_output_directory: gs://vcm-ml-intermediate
+stats:
+ output_directory: gs://vcm-ml-intermediate/2024-06-20-era5-1deg-8layer-stats-1990-2019
+ beaker_dataset: era5-1deg-8layer-stats-1990-2019-v2
+ start_date: "1990-01-01"
+ end_date: "2019-12-31"
+ data_type: ERA5
diff --git a/scripts/data_process/configs/fv3gfs-amip-ensemble-1deg-8layer.yaml b/scripts/data_process/configs/fv3gfs-amip-ensemble-1deg-8layer.yaml
new file mode 100644
index 0000000..ef7ac9e
--- /dev/null
+++ b/scripts/data_process/configs/fv3gfs-amip-ensemble-1deg-8layer.yaml
@@ -0,0 +1,81 @@
+runs:
+ ic_0001: gs://vcm-ml-raw-flexible-retention/2023-10-20-C96-FME-AMIP-ensemble-reference/regridded-zarrs/gaussian_grid_180_by_360/ic_0001
+ ic_0002: gs://vcm-ml-raw-flexible-retention/2023-10-20-C96-FME-AMIP-ensemble-reference/regridded-zarrs/gaussian_grid_180_by_360/ic_0002
+ ic_0003: gs://vcm-ml-raw-flexible-retention/2023-10-20-C96-FME-AMIP-ensemble-reference/regridded-zarrs/gaussian_grid_180_by_360/ic_0003
+ ic_0004: gs://vcm-ml-raw-flexible-retention/2023-10-20-C96-FME-AMIP-ensemble-reference/regridded-zarrs/gaussian_grid_180_by_360/ic_0004
+data_output_directory: gs://vcm-ml-intermediate/2023-10-27-vertically-resolved-1deg-fme-amip-ensemble-dataset
+stats:
+ output_directory: gs://vcm-ml-intermediate/2023-10-27-vertically-resolved-1deg-fme-amip-ensemble-dataset-stats
+ beaker_dataset: 2023-10-27-vertically-resolved-1deg-fme-amip-ensemble-dataset-stats
+ start_date: "1990-01-01"
+ end_date: "2019-12-31"
+ data_type: FV3GFS
+ exclude_runs:
+ - "ic_0004"
+dataset_computation:
+ reference_vertical_coordinate_file: gs://vcm-ml-raw-flexible-retention/2023-04-13-11-year-C96-FME-reference/vertical-coordinate-file/fv_core.res.nc
+ vertical_coarsening_indices:
+ - [0, 18]
+ - [18, 26]
+ - [26, 31]
+ - [31, 36]
+ - [36, 41]
+ - [41, 47]
+ - [47, 53]
+ - [53, 63]
+ renaming:
+ specific_humidity_at_two_meters: Q2m
+ variable_sources:
+ fluxes_2d.zarr:
+ - PRATEsfc
+ - LHTFLsfc
+ - SHTFLsfc
+ - DLWRFsfc
+ - DSWRFsfc
+ - DSWRFtoa
+ - ULWRFsfc
+ - ULWRFtoa
+ - USWRFsfc
+ - USWRFtoa
+ - precipitable_water_path
+ - GRAUPELsfc
+ - ICEsfc
+ - SNOWsfc
+ fourcastnet_vanilla.zarr:
+ - PRESsfc
+ - HGTsfc
+ - RH500
+ - RH850
+ - TMP500
+ - TMP850
+ - UGRD500
+ - UGRD850
+ - UGRD1000
+ - VGRD500
+ - VGRD850
+ - VGRD1000
+ - h50
+ - h500
+ - h850
+ - h1000
+ - TMP2m
+ - UGRD10m
+ - VGRD10m
+ full_state.zarr:
+ - surface_temperature
+ - air_temperature
+ - specific_humidity
+ - cloud_water_mixing_ratio
+ - cloud_ice_mixing_ratio
+ - graupel_mixing_ratio
+ - rain_mixing_ratio
+ - snow_mixing_ratio
+ - northward_wind
+ - eastward_wind
+ - pressure_thickness_of_atmospheric_layer
+ - soil_moisture
+ - specific_humidity_at_two_meters
+ encoded_surface_type.zarr:
+ - land_fraction
+ - ocean_fraction
+ - sea_ice_fraction
diff --git a/scripts/data_process/configs/fv3gfs-amip-ensemble-4deg-8layer.yaml b/scripts/data_process/configs/fv3gfs-amip-ensemble-4deg-8layer.yaml
new file mode 100644
index 0000000..d329622
--- /dev/null
+++ b/scripts/data_process/configs/fv3gfs-amip-ensemble-4deg-8layer.yaml
@@ -0,0 +1,81 @@
+runs:
+ ic_0001: gs://vcm-ml-raw-flexible-retention/2023-10-20-C96-FME-AMIP-ensemble-reference/regridded-zarrs/gaussian_grid_45_by_90/ic_0001
+ ic_0002: gs://vcm-ml-raw-flexible-retention/2023-10-20-C96-FME-AMIP-ensemble-reference/regridded-zarrs/gaussian_grid_45_by_90/ic_0002
+ ic_0003: gs://vcm-ml-raw-flexible-retention/2023-10-20-C96-FME-AMIP-ensemble-reference/regridded-zarrs/gaussian_grid_45_by_90/ic_0003
+ ic_0004: gs://vcm-ml-raw-flexible-retention/2023-10-20-C96-FME-AMIP-ensemble-reference/regridded-zarrs/gaussian_grid_45_by_90/ic_0004
+data_output_directory: gs://vcm-ml-intermediate/2023-10-27-vertically-resolved-4deg-fme-amip-ensemble-dataset
+stats:
+ output_directory: gs://vcm-ml-intermediate/2023-10-27-vertically-resolved-4deg-fme-amip-ensemble-dataset-stats
+ beaker_dataset: 2023-10-27-vertically-resolved-4deg-fme-amip-ensemble-dataset-stats
+ start_date: "1990-01-01"
+ end_date: "2019-12-31"
+ exclude_runs:
+ - "ic_0004"
+ data_type: FV3GFS
+dataset_computation:
+ reference_vertical_coordinate_file: gs://vcm-ml-raw-flexible-retention/2023-04-13-11-year-C96-FME-reference/vertical-coordinate-file/fv_core.res.nc
+ vertical_coarsening_indices:
+ - [0, 18]
+ - [18, 26]
+ - [26, 31]
+ - [31, 36]
+ - [36, 41]
+ - [41, 47]
+ - [47, 53]
+ - [53, 63]
+ renaming:
+ specific_humidity_at_two_meters: Q2m
+ variable_sources:
+ fluxes_2d.zarr:
+ - PRATEsfc
+ - LHTFLsfc
+ - SHTFLsfc
+ - DLWRFsfc
+ - DSWRFsfc
+ - DSWRFtoa
+ - ULWRFsfc
+ - ULWRFtoa
+ - USWRFsfc
+ - USWRFtoa
+ - precipitable_water_path
+ - GRAUPELsfc
+ - ICEsfc
+ - SNOWsfc
+ fourcastnet_vanilla.zarr:
+ - PRESsfc
+ - HGTsfc
+ - RH500
+ - RH850
+ - TMP500
+ - TMP850
+ - UGRD500
+ - UGRD850
+ - UGRD1000
+ - VGRD500
+ - VGRD850
+ - VGRD1000
+ - h50
+ - h500
+ - h850
+ - h1000
+ - TMP2m
+ - UGRD10m
+ - VGRD10m
+ full_state.zarr:
+ - surface_temperature
+ - air_temperature
+ - specific_humidity
+ - cloud_water_mixing_ratio
+ - cloud_ice_mixing_ratio
+ - graupel_mixing_ratio
+ - rain_mixing_ratio
+ - snow_mixing_ratio
+ - northward_wind
+ - eastward_wind
+ - pressure_thickness_of_atmospheric_layer
+ - soil_moisture
+ - specific_humidity_at_two_meters
+ encoded_surface_type.zarr:
+ - land_fraction
+ - ocean_fraction
+ - sea_ice_fraction
diff --git a/scripts/data_process/configs/fv3gfs-c48-ensemble-1deg-8layer.yaml b/scripts/data_process/configs/fv3gfs-c48-ensemble-1deg-8layer.yaml
new file mode 100644
index 0000000..fdaff68
--- /dev/null
+++ b/scripts/data_process/configs/fv3gfs-c48-ensemble-1deg-8layer.yaml
@@ -0,0 +1,77 @@
+runs:
+ ic_0011: gs://vcm-ml-raw-flexible-retention/2023-08-03-C48-FME-reference-ensemble/regridded-zarrs/gaussian_grid_180_by_360/ic_0011_2021010100
+data_output_directory: gs://vcm-ml-intermediate/2023-09-01-vertically-resolved-1deg-fme-c48-baseline-dataset
+stats:
+ output_directory: gs://vcm-ml-intermediate/2023-09-01-vertically-resolved-1deg-fme-c48-baseline-dataset-stats
+ beaker_dataset: 2023-09-01-vertically-resolved-1deg-fme-c48-baseline-dataset-stats
+ start_date: "2021-01-01"
+ end_date: "2030-12-31"
+ data_type: FV3GFS
+dataset_computation:
+ reference_vertical_coordinate_file: gs://vcm-ml-raw-flexible-retention/2023-04-13-11-year-C96-FME-reference/vertical-coordinate-file/fv_core.res.nc
+ vertical_coarsening_indices:
+ - [0, 18]
+ - [18, 26]
+ - [26, 31]
+ - [31, 36]
+ - [36, 41]
+ - [41, 47]
+ - [47, 53]
+ - [53, 63]
+ renaming:
+ specific_humidity_at_two_meters: Q2m
+ variable_sources:
+ fluxes_2d.zarr:
+ - PRATEsfc
+ - LHTFLsfc
+ - SHTFLsfc
+ - DLWRFsfc
+ - DSWRFsfc
+ - DSWRFtoa
+ - ULWRFsfc
+ - ULWRFtoa
+ - USWRFsfc
+ - USWRFtoa
+ - precipitable_water_path
+ - GRAUPELsfc
+ - ICEsfc
+ - SNOWsfc
+ fourcastnet_vanilla.zarr:
+ - PRESsfc
+ - HGTsfc
+ - RH500
+ - RH850
+ - TMP500
+ - TMP850
+ - UGRD500
+ - UGRD850
+ - UGRD1000
+ - VGRD500
+ - VGRD850
+ - VGRD1000
+ - h50
+ - h500
+ - h850
+ - h1000
+ - TMP2m
+ - UGRD10m
+ - VGRD10m
+ full_state.zarr:
+ - surface_temperature
+ - air_temperature
+ - specific_humidity
+ - cloud_water_mixing_ratio
+ - cloud_ice_mixing_ratio
+ - graupel_mixing_ratio
+ - rain_mixing_ratio
+ - snow_mixing_ratio
+ - northward_wind
+ - eastward_wind
+ - pressure_thickness_of_atmospheric_layer
+ - soil_moisture
+ - specific_humidity_at_two_meters
+ encoded_surface_type.zarr:
+ - land_fraction
+ - ocean_fraction
+ - sea_ice_fraction
+ roundtrip_fraction_kept: 0.65
diff --git a/scripts/data_process/configs/fv3gfs-ensemble-1deg-8layer.yaml b/scripts/data_process/configs/fv3gfs-ensemble-1deg-8layer.yaml
new file mode 100644
index 0000000..08a8379
--- /dev/null
+++ b/scripts/data_process/configs/fv3gfs-ensemble-1deg-8layer.yaml
@@ -0,0 +1,88 @@
+runs:
+ ic_0001: gs://vcm-ml-raw-flexible-retention/2023-07-08-C96-FME-reference-ensemble/regridded-zarrs/gaussian_grid_180_by_360/ic_0001
+ ic_0002: gs://vcm-ml-raw-flexible-retention/2023-07-08-C96-FME-reference-ensemble/regridded-zarrs/gaussian_grid_180_by_360/ic_0002
+ ic_0003: gs://vcm-ml-raw-flexible-retention/2023-07-08-C96-FME-reference-ensemble/regridded-zarrs/gaussian_grid_180_by_360/ic_0003
+ ic_0004: gs://vcm-ml-raw-flexible-retention/2023-07-08-C96-FME-reference-ensemble/regridded-zarrs/gaussian_grid_180_by_360/ic_0004
+ ic_0005: gs://vcm-ml-raw-flexible-retention/2023-07-08-C96-FME-reference-ensemble/regridded-zarrs/gaussian_grid_180_by_360/ic_0005
+ ic_0006: gs://vcm-ml-raw-flexible-retention/2023-07-08-C96-FME-reference-ensemble/regridded-zarrs/gaussian_grid_180_by_360/ic_0006
+ ic_0007: gs://vcm-ml-raw-flexible-retention/2023-07-08-C96-FME-reference-ensemble/regridded-zarrs/gaussian_grid_180_by_360/ic_0007
+ ic_0008: gs://vcm-ml-raw-flexible-retention/2023-07-08-C96-FME-reference-ensemble/regridded-zarrs/gaussian_grid_180_by_360/ic_0008
+ ic_0009: gs://vcm-ml-raw-flexible-retention/2023-07-08-C96-FME-reference-ensemble/regridded-zarrs/gaussian_grid_180_by_360/ic_0009
+ ic_0010: gs://vcm-ml-raw-flexible-retention/2023-07-08-C96-FME-reference-ensemble/regridded-zarrs/gaussian_grid_180_by_360/ic_0010
+ ic_0011: gs://vcm-ml-raw-flexible-retention/2023-07-08-C96-FME-reference-ensemble/regridded-zarrs/gaussian_grid_180_by_360/ic_0011
+data_output_directory: gs://vcm-ml-intermediate/2023-08-09-vertically-resolved-1deg-fme-ensemble-dataset
+stats:
+ output_directory: gs://vcm-ml-intermediate/2023-08-09-vertically-resolved-1deg-fme-ensemble-dataset-stats
+ beaker_dataset: 2023-08-09-vertically-resolved-1deg-fme-ensemble-dataset-stats
+ start_date: "2021-01-01"
+ end_date: "2030-12-31"
+ data_type: FV3GFS
+ exclude_runs:
+ - "ic_0011"
+dataset_computation:
+ reference_vertical_coordinate_file: gs://vcm-ml-raw-flexible-retention/2023-04-13-11-year-C96-FME-reference/vertical-coordinate-file/fv_core.res.nc
+ vertical_coarsening_indices:
+ - [0, 18]
+ - [18, 26]
+ - [26, 31]
+ - [31, 36]
+ - [36, 41]
+ - [41, 47]
+ - [47, 53]
+ - [53, 63]
+ renaming:
+ specific_humidity_at_two_meters: Q2m
+ variable_sources:
+ fluxes_2d.zarr:
+ - PRATEsfc
+ - LHTFLsfc
+ - SHTFLsfc
+ - DLWRFsfc
+ - DSWRFsfc
+ - DSWRFtoa
+ - ULWRFsfc
+ - ULWRFtoa
+ - USWRFsfc
+ - USWRFtoa
+ - precipitable_water_path
+ - GRAUPELsfc
+ - ICEsfc
+ - SNOWsfc
+ fourcastnet_vanilla.zarr:
+ - PRESsfc
+ - HGTsfc
+ - RH500
+ - RH850
+ - TMP500
+ - TMP850
+ - UGRD500
+ - UGRD850
+ - UGRD1000
+ - VGRD500
+ - VGRD850
+ - VGRD1000
+ - h50
+ - h500
+ - h850
+ - h1000
+ - TMP2m
+ - UGRD10m
+ - VGRD10m
+ full_state.zarr:
+ - surface_temperature
+ - air_temperature
+ - specific_humidity
+ - cloud_water_mixing_ratio
+ - cloud_ice_mixing_ratio
+ - graupel_mixing_ratio
+ - rain_mixing_ratio
+ - snow_mixing_ratio
+ - northward_wind
+ - eastward_wind
+ - pressure_thickness_of_atmospheric_layer
+ - soil_moisture
+ - specific_humidity_at_two_meters
+ encoded_surface_type.zarr:
+ - land_fraction
+ - ocean_fraction
+ - sea_ice_fraction
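For each run, `variable_sources` maps a source zarr store to the variables pulled from it, and `renaming` maps long raw names to the short names used downstream (e.g. specific_humidity_at_two_meters -> Q2m). A minimal sketch of how such a config might be consumed, assuming the stores live directly under the run URL (function and argument names are illustrative, not the pipeline's actual interface):

    import os
    import xarray as xr

    def open_run(run_url: str, variable_sources: dict, renaming: dict) -> xr.Dataset:
        # Open each zarr store, keep only the configured variables, then merge
        # everything into one dataset and apply the short-name renaming.
        pieces = [
            xr.open_zarr(os.path.join(run_url, store))[variables]
            for store, variables in variable_sources.items()
        ]
        return xr.merge(pieces).rename(renaming)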
diff --git a/scripts/data_process/configs/fv3gfs-ensemble-4deg-8layer.yaml b/scripts/data_process/configs/fv3gfs-ensemble-4deg-8layer.yaml
new file mode 100644
index 0000000..22a0c2b
--- /dev/null
+++ b/scripts/data_process/configs/fv3gfs-ensemble-4deg-8layer.yaml
@@ -0,0 +1,88 @@
+runs:
+ ic_0001: gs://vcm-ml-raw-flexible-retention/2023-07-08-C96-FME-reference-ensemble/regridded-zarrs/gaussian_grid_45_by_90/ic_0001
+ ic_0002: gs://vcm-ml-raw-flexible-retention/2023-07-08-C96-FME-reference-ensemble/regridded-zarrs/gaussian_grid_45_by_90/ic_0002
+ ic_0003: gs://vcm-ml-raw-flexible-retention/2023-07-08-C96-FME-reference-ensemble/regridded-zarrs/gaussian_grid_45_by_90/ic_0003
+ ic_0004: gs://vcm-ml-raw-flexible-retention/2023-07-08-C96-FME-reference-ensemble/regridded-zarrs/gaussian_grid_45_by_90/ic_0004
+ ic_0005: gs://vcm-ml-raw-flexible-retention/2023-07-08-C96-FME-reference-ensemble/regridded-zarrs/gaussian_grid_45_by_90/ic_0005
+ ic_0006: gs://vcm-ml-raw-flexible-retention/2023-07-08-C96-FME-reference-ensemble/regridded-zarrs/gaussian_grid_45_by_90/ic_0006
+ ic_0007: gs://vcm-ml-raw-flexible-retention/2023-07-08-C96-FME-reference-ensemble/regridded-zarrs/gaussian_grid_45_by_90/ic_0007
+ ic_0008: gs://vcm-ml-raw-flexible-retention/2023-07-08-C96-FME-reference-ensemble/regridded-zarrs/gaussian_grid_45_by_90/ic_0008
+ ic_0009: gs://vcm-ml-raw-flexible-retention/2023-07-08-C96-FME-reference-ensemble/regridded-zarrs/gaussian_grid_45_by_90/ic_0009
+ ic_0010: gs://vcm-ml-raw-flexible-retention/2023-07-08-C96-FME-reference-ensemble/regridded-zarrs/gaussian_grid_45_by_90/ic_0010
+ ic_0011: gs://vcm-ml-raw-flexible-retention/2023-07-08-C96-FME-reference-ensemble/regridded-zarrs/gaussian_grid_45_by_90/ic_0011
+data_output_directory: gs://vcm-ml-intermediate/2023-08-09-vertically-resolved-4deg-fme-ensemble-dataset
+stats:
+ output_directory: gs://vcm-ml-intermediate/2023-08-09-vertically-resolved-4deg-fme-ensemble-dataset-stats
+ beaker_dataset: 2023-08-09-vertically-resolved-4deg-fme-ensemble-dataset-stats
+ start_date: "2021-01-01"
+ end_date: "2030-12-31"
+ data_type: FV3GFS
+ exclude_runs:
+ - "ic_0011"
+dataset_computation:
+ reference_vertical_coordinate_file: gs://vcm-ml-raw-flexible-retention/2023-04-13-11-year-C96-FME-reference/vertical-coordinate-file/fv_core.res.nc
+ vertical_coarsening_indices:
+ - [0, 18]
+ - [18, 26]
+ - [26, 31]
+ - [31, 36]
+ - [36, 41]
+ - [41, 47]
+ - [47, 53]
+ - [53, 63]
+ renaming:
+ specific_humidity_at_two_meters: Q2m
+ variable_sources:
+ fluxes_2d.zarr:
+ - PRATEsfc
+ - LHTFLsfc
+ - SHTFLsfc
+ - DLWRFsfc
+ - DSWRFsfc
+ - DSWRFtoa
+ - ULWRFsfc
+ - ULWRFtoa
+ - USWRFsfc
+ - USWRFtoa
+ - precipitable_water_path
+ - GRAUPELsfc
+ - ICEsfc
+ - SNOWsfc
+ fourcastnet_vanilla.zarr:
+ - PRESsfc
+ - HGTsfc
+ - RH500
+ - RH850
+ - TMP500
+ - TMP850
+ - UGRD500
+ - UGRD850
+ - UGRD1000
+ - VGRD500
+ - VGRD850
+ - VGRD1000
+ - h50
+ - h500
+ - h850
+ - h1000
+ - TMP2m
+ - UGRD10m
+ - VGRD10m
+ full_state.zarr:
+ - surface_temperature
+ - air_temperature
+ - specific_humidity
+ - cloud_water_mixing_ratio
+ - cloud_ice_mixing_ratio
+ - graupel_mixing_ratio
+ - rain_mixing_ratio
+ - snow_mixing_ratio
+ - northward_wind
+ - eastward_wind
+ - pressure_thickness_of_atmospheric_layer
+ - soil_moisture
+ - specific_humidity_at_two_meters
+ encoded_surface_type.zarr:
+ - land_fraction
+ - ocean_fraction
+ - sea_ice_fraction
diff --git a/scripts/data_process/configs/healpix-1deg-8layer-1940-2022.yaml b/scripts/data_process/configs/healpix-1deg-8layer-1940-2022.yaml
new file mode 100644
index 0000000..0f56b31
--- /dev/null
+++ b/scripts/data_process/configs/healpix-1deg-8layer-1940-2022.yaml
@@ -0,0 +1,119 @@
+runs:
+ run_directory: /mntdata
+stats:
+ output_directory: /mntdata/2024-08-21-healpix-era5-dataset
+ beaker_dataset: 2024-08-21-healpix-era5-dataset-stats
+ start_date: "1990-01-01"
+ end_date: "2019-12-31"
+ data_type: ERA5
+data_output_directory: /mntdata/2024-08-21-healpix-era5-dataset
+dataset_computation:
+ n_split: 400
+ reference_vertical_coordinate_file: gs://vcm-ml-raw-flexible-retention/2023-04-13-11-year-C96-FME-reference/vertical-coordinate-file/fv_core.res.nc
+ vertical_coarsening_indices:
+ - [0, 18]
+ standard_names:
+ face_dim: "face"
+ chunking:
+ face_dim: -1
+ variable_sources:
+ 2024-06-20-era5-1deg-8layer-1940-2022.zarr:
+ - DLWRFsfc
+ - DPT2m
+ - DSWRFsfc
+ - DSWRFtoa
+ - HGTsfc
+ - LHTFLsfc
+ - PRATEsfc
+ - PRESsfc
+ - Q200
+ - Q2m
+ - Q500
+ - Q850
+ - SHTFLsfc
+ - TMP200
+ - TMP2m
+ - TMP500
+ - TMP850
+ - UGRD10m
+ - UGRD200
+ - UGRD500
+ - UGRD850
+ - ULWRFsfc
+ - ULWRFtoa
+ - USWRFsfc
+ - USWRFtoa
+ - VGRD10m
+ - VGRD200
+ - VGRD500
+ - VGRD850
+ - air_temperature_0
+ - air_temperature_1
+ - air_temperature_2
+ - air_temperature_3
+ - air_temperature_4
+ - air_temperature_5
+ - air_temperature_6
+ - air_temperature_7
+ - ak_0
+ - ak_1
+ - ak_2
+ - ak_3
+ - ak_4
+ - ak_5
+ - ak_6
+ - ak_7
+ - ak_8
+ - bk_0
+ - bk_1
+ - bk_2
+ - bk_3
+ - bk_4
+ - bk_5
+ - bk_6
+ - bk_7
+ - bk_8
+ - eastward_wind_0
+ - eastward_wind_1
+ - eastward_wind_2
+ - eastward_wind_3
+ - eastward_wind_4
+ - eastward_wind_5
+ - eastward_wind_6
+ - eastward_wind_7
+ - h1000
+ - h200
+ - h250
+ - h300
+ - h500
+ - h700
+ - h850
+ - land_fraction
+ - latitude
+ - longitude
+ - northward_wind_0
+ - northward_wind_1
+ - northward_wind_2
+ - northward_wind_3
+ - northward_wind_4
+ - northward_wind_5
+ - northward_wind_6
+ - northward_wind_7
+ - ocean_fraction
+ - sea_ice_fraction
+ - soil_moisture_0
+ - soil_moisture_1
+ - soil_moisture_2
+ - soil_moisture_3
+ - specific_total_water_0
+ - specific_total_water_1
+ - specific_total_water_2
+ - specific_total_water_3
+ - specific_total_water_4
+ - specific_total_water_5
+ - specific_total_water_6
+ - specific_total_water_7
+ - surface_temperature
+ - tendency_of_total_water_path_due_to_advection
+ - time
+ - total_column_water_vapour
\ No newline at end of file
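In the HEALPix config above, `standard_names.face_dim` declares which dimension holds the HEALPix faces, and `chunking: face_dim: -1` requests a single chunk spanning that whole dimension (-1 is dask's convention for "do not split this dimension"). Roughly, in xarray terms (dummy dataset for illustration only):

    import numpy as np
    import xarray as xr

    # Stand-in for a HEALPix-regridded dataset with a "face" dimension.
    ds_hpx = xr.Dataset(
        {"TMP2m": (("time", "face", "height", "width"), np.zeros((4, 12, 8, 8)))}
    )
    # chunking: face_dim: -1  ->  keep the entire "face" dimension in one chunk.
    ds_hpx = ds_hpx.chunk({"face": -1, "time": 1})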
diff --git a/scripts/data_process/configs/pre-industrial-CM4-1deg-8layer-trial-run.yaml b/scripts/data_process/configs/pre-industrial-CM4-1deg-8layer-trial-run.yaml
new file mode 100644
index 0000000..9342e6f
--- /dev/null
+++ b/scripts/data_process/configs/pre-industrial-CM4-1deg-8layer-trial-run.yaml
@@ -0,0 +1,71 @@
+runs:
+ 2024-09-20-cm4-1deg-8layer-trial-run: gs://vcm-ml-raw-flexible-retention/2024-08-10-pre-industrial-CM4-simulation/regridded-zarrs/gaussian_grid_180_by_360/trial-run
+data_output_directory: gs://vcm-ml-intermediate
+stats:
+ output_directory: gs://vcm-ml-intermediate/2024-09-20-cm4-1deg-8layer-trial-run-stats
+ start_date: "0151-01-01"
+ end_date: "0159-01-01"
+ data_type: CM4
+ beaker_dataset: not-used
+dataset_computation:
+ reference_vertical_coordinate_file: gs://vcm-ml-raw-flexible-retention/2024-08-10-pre-industrial-CM4-simulation/vertical-coordinate-file/fv_core.res.nc
+ vertical_coarsening_indices:
+ # computed here: https://github.com/ai2cm/explore/blob/master/jamesd/2024-08-13-pre-industiral-CM4-eda/2024-08-28-AM4-vertical-indices.ipynb
+ - [0, 7]
+ - [7, 10]
+ - [10, 13]
+ - [13, 16]
+ - [16, 18]
+ - [18, 22]
+ - [22, 25]
+ - [25, 33]
+ renaming:
+ specific_humidity_at_two_meters: Q2m
+ air_temperature_at_two_meters: TMP2m
+ eastward_wind_at_ten_meters: UGRD10m
+ northward_wind_at_ten_meters: VGRD10m
+ variable_sources:
+ fluxes_2d.zarr:
+ - PRATEsfc
+ - SHTFLsfc
+ - DLWRFsfc
+ - DSWRFsfc
+ - DSWRFtoa
+ - ULWRFsfc
+ - ULWRFtoa
+ - USWRFsfc
+ - USWRFtoa
+ - eastward_surface_wind_stress
+ - northward_surface_wind_stress
+ - surface_evaporation_rate
+ - total_energy
+ - total_frozen_precipitation_rate
+ full_state.zarr:
+ # 2D vars
+ - HGTsfc # static
+ - PRESsfc
+ - surface_temperature
+ - air_temperature_at_two_meters
+ - specific_humidity_at_two_meters
+ - eastward_wind_at_ten_meters
+ - northward_wind_at_ten_meters
+ # 3D vars:
+ - air_temperature
+ - specific_humidity # water species
+ - cloud_water_mixing_ratio # water species
+ - cloud_ice_mixing_ratio # water species
+ - eastward_wind
+ - northward_wind
+ land_static.zarr:
+ - land_fraction
+ full_state_land.zarr:
+ - column_soil_moisture
+ full_state_ice.zarr:
+ - sea_ice_fraction
+ standard_names:
+ longitude_dim: lon
+ latitude_dim: lat
+ graupel_mixing_ratio: None
+ rain_mixing_ratio: None
+ snow_mixing_ratio: None
+ precipitable_water_path: None
diff --git a/scripts/data_process/configs/shield-amip-ensemble-c24-4deg-8layer.yaml b/scripts/data_process/configs/shield-amip-ensemble-c24-4deg-8layer.yaml
new file mode 100644
index 0000000..24103f8
--- /dev/null
+++ b/scripts/data_process/configs/shield-amip-ensemble-c24-4deg-8layer.yaml
@@ -0,0 +1,90 @@
+runs:
+ ic_0001: gs://vcm-ml-raw-flexible-retention/2024-11-11-C24-SHiELD-AMIP-tuned-cdmbgwd/regridded-zarrs/gaussian_grid_45_by_90/ic_0001
+ ic_0002: gs://vcm-ml-raw-flexible-retention/2024-11-11-C24-SHiELD-AMIP-tuned-cdmbgwd/regridded-zarrs/gaussian_grid_45_by_90/ic_0002
+data_output_directory: gs://vcm-ml-intermediate/2024-11-11-vertically-resolved-c24-4deg-shield-amip-tuned-cdmbgwd-ensemble-dataset
+stats:
+ output_directory: gs://vcm-ml-intermediate/2024-11-11-vertically-resolved-c24-4deg-shield-amip-tuned-cdmbgwd-ensemble-dataset-stats
+ beaker_dataset: 2024-11-11-vertically-resolved-c24-4deg-shield-amip-tuned-cdmbgwd-ensemble-dataset-stats
+ start_date: "1940-01-01"
+ end_date: "2021-12-31"
+ data_type: FV3GFS
+ exclude_runs:
+ - "ic_0002"
+dataset_computation:
+ reference_vertical_coordinate_file: gs://vcm-ml-raw-flexible-retention/2024-03-10-C96-SHiELD-FME-reference/vertical-coordinate-file/fv_core.res.nc
+ vertical_coarsening_indices:
+ - [0, 11]
+ - [11, 21]
+ - [21, 30]
+ - [30, 39]
+ - [39, 49]
+ - [49, 58]
+ - [58, 67]
+ - [67, 79]
+ renaming:
+ specific_humidity_at_two_meters: Q2m
+ air_temperature_at_two_meters: TMP2m
+ eastward_wind_at_ten_meters: UGRD10m
+ northward_wind_at_ten_meters: VGRD10m
+ variable_sources:
+ fluxes_2d.zarr:
+ - PRATEsfc
+ - LHTFLsfc
+ - SHTFLsfc
+ - DLWRFsfc
+ - DSWRFsfc
+ - DSWRFtoa
+ - ULWRFsfc
+ - ULWRFtoa
+ - USWRFsfc
+ - USWRFtoa
+ - precipitable_water_path
+ - GRAUPELsfc
+ - ICEsfc
+ - SNOWsfc
+ full_state.zarr:
+ - PRESsfc
+ - HGTsfc
+ - RH500
+ - RH850
+ - TMP500
+ - TMP850
+ - UGRD500
+ - UGRD850
+ - UGRD1000
+ - VGRD500
+ - VGRD850
+ - VGRD1000
+ - h50
+ - h500
+ - h850
+ - h1000
+ - air_temperature_at_two_meters
+ - eastward_wind_at_ten_meters
+ - northward_wind_at_ten_meters
+ - surface_temperature
+ - air_temperature
+ - specific_humidity
+ - cloud_water_mixing_ratio
+ - cloud_ice_mixing_ratio
+ - graupel_mixing_ratio
+ - rain_mixing_ratio
+ - snow_mixing_ratio
+ - northward_wind
+ - eastward_wind
+ - pressure_thickness_of_atmospheric_layer
+ - soil_moisture_0
+ - soil_moisture_1
+ - soil_moisture_2
+ - soil_moisture_3
+ - snow_cover_fraction
+ - specific_humidity_at_two_meters
+ - land_fraction
+ - ocean_fraction
+ - sea_ice_fraction
+ - UGRD200
+ - VGRD200
+ - TMP200
+ - RH200
+ scalar.zarr:
+ - global_mean_co2
diff --git a/scripts/data_process/configs/shield-amip-ensemble-c96-1deg-8layer.yaml b/scripts/data_process/configs/shield-amip-ensemble-c96-1deg-8layer.yaml
new file mode 100644
index 0000000..c946ed0
--- /dev/null
+++ b/scripts/data_process/configs/shield-amip-ensemble-c96-1deg-8layer.yaml
@@ -0,0 +1,91 @@
+runs:
+ ic_0001: gs://vcm-ml-raw-flexible-retention/2024-06-29-C96-SHiELD-AMIP/regridded-zarrs/gaussian_grid_180_by_360/ic_0001
+ ic_0002: gs://vcm-ml-raw-flexible-retention/2024-06-29-C96-SHiELD-AMIP/regridded-zarrs/gaussian_grid_180_by_360/ic_0002
+data_output_directory: gs://vcm-ml-intermediate/2024-07-24-vertically-resolved-c96-1deg-shield-amip-ensemble-dataset
+stats:
+ output_directory: gs://vcm-ml-intermediate/2024-07-24-vertically-resolved-c96-1deg-shield-amip-ensemble-dataset-stats
+ beaker_dataset: 2024-07-24-vertically-resolved-c96-1deg-shield-amip-ensemble-dataset-stats
+ start_date: "1940-01-01"
+ end_date: "2021-12-31"
+ data_type: FV3GFS
+ exclude_runs:
+ - "ic_0002"
+dataset_computation:
+ reference_vertical_coordinate_file: gs://vcm-ml-raw-flexible-retention/2024-03-10-C96-SHiELD-FME-reference/vertical-coordinate-file/fv_core.res.nc
+ vertical_coarsening_indices:
+ - [0, 11]
+ - [11, 21]
+ - [21, 30]
+ - [30, 39]
+ - [39, 49]
+ - [49, 58]
+ - [58, 67]
+ - [67, 79]
+ renaming:
+ specific_humidity_at_two_meters: Q2m
+ air_temperature_at_two_meters: TMP2m
+ eastward_wind_at_ten_meters: UGRD10m
+ northward_wind_at_ten_meters: VGRD10m
+ variable_sources:
+ fluxes_2d.zarr:
+ - PRATEsfc
+ - LHTFLsfc
+ - SHTFLsfc
+ - DLWRFsfc
+ - DSWRFsfc
+ - DSWRFtoa
+ - ULWRFsfc
+ - ULWRFtoa
+ - USWRFsfc
+ - USWRFtoa
+ - precipitable_water_path
+ - GRAUPELsfc
+ - ICEsfc
+ - SNOWsfc
+ full_state.zarr:
+ - PRESsfc
+ - HGTsfc
+ - RH500
+ - RH850
+ - TMP500
+ - TMP850
+ - UGRD500
+ - UGRD850
+ - UGRD1000
+ - VGRD500
+ - VGRD850
+ - VGRD1000
+ - h50
+ - h500
+ - h850
+ - h1000
+ - air_temperature_at_two_meters
+ - eastward_wind_at_ten_meters
+ - northward_wind_at_ten_meters
+ - surface_temperature
+ - air_temperature
+ - specific_humidity
+ - cloud_water_mixing_ratio
+ - cloud_ice_mixing_ratio
+ - graupel_mixing_ratio
+ - rain_mixing_ratio
+ - snow_mixing_ratio
+ - northward_wind
+ - eastward_wind
+ - pressure_thickness_of_atmospheric_layer
+ - soil_moisture_0
+ - soil_moisture_1
+ - soil_moisture_2
+ - soil_moisture_3
+ - snow_cover_fraction
+ - specific_humidity_at_two_meters
+ - land_fraction
+ - ocean_fraction
+ - sea_ice_fraction
+ UGRD200_VGRD200_TMP200_RH200.zarr:
+ - UGRD200
+ - VGRD200
+ - TMP200
+ - RH200
+ scalar.zarr:
+ - global_mean_co2
diff --git a/scripts/data_process/configs/shield-amip-ensemble-c96-4deg-8layer.yaml b/scripts/data_process/configs/shield-amip-ensemble-c96-4deg-8layer.yaml
new file mode 100644
index 0000000..a2ce8e6
--- /dev/null
+++ b/scripts/data_process/configs/shield-amip-ensemble-c96-4deg-8layer.yaml
@@ -0,0 +1,90 @@
+runs:
+ ic_0001: gs://vcm-ml-raw-flexible-retention/2024-06-29-C96-SHiELD-AMIP/regridded-zarrs/gaussian_grid_45_by_90/ic_0001
+ ic_0002: gs://vcm-ml-raw-flexible-retention/2024-06-29-C96-SHiELD-AMIP/regridded-zarrs/gaussian_grid_45_by_90/ic_0002
+data_output_directory: gs://vcm-ml-intermediate/2024-07-24-vertically-resolved-c96-4deg-shield-amip-ensemble-dataset
+stats:
+ output_directory: gs://vcm-ml-intermediate/2024-07-24-vertically-resolved-c96-4deg-shield-amip-ensemble-dataset-stats
+ beaker_dataset: 2024-07-24-vertically-resolved-c96-4deg-shield-amip-ensemble-dataset-stats
+ start_date: "1940-01-01"
+ end_date: "2021-12-31"
+ data_type: FV3GFS
+ exclude_runs:
+ - "ic_0002"
+dataset_computation:
+ reference_vertical_coordinate_file: gs://vcm-ml-raw-flexible-retention/2024-03-10-C96-SHiELD-FME-reference/vertical-coordinate-file/fv_core.res.nc
+ vertical_coarsening_indices:
+ - [0, 11]
+ - [11, 21]
+ - [21, 30]
+ - [30, 39]
+ - [39, 49]
+ - [49, 58]
+ - [58, 67]
+ - [67, 79]
+ renaming:
+ specific_humidity_at_two_meters: Q2m
+ air_temperature_at_two_meters: TMP2m
+ eastward_wind_at_ten_meters: UGRD10m
+ northward_wind_at_ten_meters: VGRD10m
+ variable_sources:
+ fluxes_2d.zarr:
+ - PRATEsfc
+ - LHTFLsfc
+ - SHTFLsfc
+ - DLWRFsfc
+ - DSWRFsfc
+ - DSWRFtoa
+ - ULWRFsfc
+ - ULWRFtoa
+ - USWRFsfc
+ - USWRFtoa
+ - precipitable_water_path
+ - GRAUPELsfc
+ - ICEsfc
+ - SNOWsfc
+ full_state.zarr:
+ - PRESsfc
+ - HGTsfc
+ - RH500
+ - RH850
+ - TMP500
+ - TMP850
+ - UGRD500
+ - UGRD850
+ - VGRD500
+ - VGRD850
+ - h50
+ - h200
+ - h500
+ - h850
+ - air_temperature_at_two_meters
+ - eastward_wind_at_ten_meters
+ - northward_wind_at_ten_meters
+ - surface_temperature
+ - air_temperature
+ - specific_humidity
+ - cloud_water_mixing_ratio
+ - cloud_ice_mixing_ratio
+ - graupel_mixing_ratio
+ - rain_mixing_ratio
+ - snow_mixing_ratio
+ - northward_wind
+ - eastward_wind
+ - pressure_thickness_of_atmospheric_layer
+ - soil_moisture_0
+ - soil_moisture_1
+ - soil_moisture_2
+ - soil_moisture_3
+ - column_soil_moisture
+ - snow_cover_fraction
+ - specific_humidity_at_two_meters
+ - land_fraction
+ - ocean_fraction
+ - sea_ice_fraction
+ UGRD200_VGRD200_TMP200_RH200.zarr:
+ - UGRD200
+ - VGRD200
+ - TMP200
+ - RH200
+ scalar.zarr:
+ - global_mean_co2
diff --git a/scripts/data_process/configs/shield-c24-ensemble-4deg-8layer.yaml b/scripts/data_process/configs/shield-c24-ensemble-4deg-8layer.yaml
new file mode 100644
index 0000000..fd12e65
--- /dev/null
+++ b/scripts/data_process/configs/shield-c24-ensemble-4deg-8layer.yaml
@@ -0,0 +1,89 @@
+runs:
+ ic_0001: gs://vcm-ml-raw-flexible-retention/2024-03-08-climSST-C24-SHiELD-FME-reference/regridded-zarrs/gaussian_grid_45_by_90/ic_0001
+ ic_0002: gs://vcm-ml-raw-flexible-retention/2024-03-08-climSST-C24-SHiELD-FME-reference/regridded-zarrs/gaussian_grid_45_by_90/ic_0002
+ ic_0003: gs://vcm-ml-raw-flexible-retention/2024-03-08-climSST-C24-SHiELD-FME-reference/regridded-zarrs/gaussian_grid_45_by_90/ic_0003
+ ic_0004: gs://vcm-ml-raw-flexible-retention/2024-03-08-climSST-C24-SHiELD-FME-reference/regridded-zarrs/gaussian_grid_45_by_90/ic_0004
+ ic_0005: gs://vcm-ml-raw-flexible-retention/2024-03-08-climSST-C24-SHiELD-FME-reference/regridded-zarrs/gaussian_grid_45_by_90/ic_0005
+ ic_0006: gs://vcm-ml-raw-flexible-retention/2024-03-08-climSST-C24-SHiELD-FME-reference/regridded-zarrs/gaussian_grid_45_by_90/ic_0006
+ ic_0007: gs://vcm-ml-raw-flexible-retention/2024-03-08-climSST-C24-SHiELD-FME-reference/regridded-zarrs/gaussian_grid_45_by_90/ic_0007
+ ic_0008: gs://vcm-ml-raw-flexible-retention/2024-03-08-climSST-C24-SHiELD-FME-reference/regridded-zarrs/gaussian_grid_45_by_90/ic_0008
+ ic_0009: gs://vcm-ml-raw-flexible-retention/2024-03-08-climSST-C24-SHiELD-FME-reference/regridded-zarrs/gaussian_grid_45_by_90/ic_0009
+ ic_0010: gs://vcm-ml-raw-flexible-retention/2024-03-08-climSST-C24-SHiELD-FME-reference/regridded-zarrs/gaussian_grid_45_by_90/ic_0010
+ ic_0011: gs://vcm-ml-raw-flexible-retention/2024-03-08-climSST-C24-SHiELD-FME-reference/regridded-zarrs/gaussian_grid_45_by_90/ic_0011
+ ic_0012: gs://vcm-ml-raw-flexible-retention/2024-03-08-climSST-C24-SHiELD-FME-reference/regridded-zarrs/gaussian_grid_45_by_90/ic_0012
+ ic_0013: gs://vcm-ml-raw-flexible-retention/2024-03-08-climSST-C24-SHiELD-FME-reference/regridded-zarrs/gaussian_grid_45_by_90/ic_0013
+ ic_0014: gs://vcm-ml-raw-flexible-retention/2024-03-08-climSST-C24-SHiELD-FME-reference/regridded-zarrs/gaussian_grid_45_by_90/ic_0014
+ ic_0015: gs://vcm-ml-raw-flexible-retention/2024-03-08-climSST-C24-SHiELD-FME-reference/regridded-zarrs/gaussian_grid_45_by_90/ic_0015
+ ic_0016: gs://vcm-ml-raw-flexible-retention/2024-03-08-climSST-C24-SHiELD-FME-reference/regridded-zarrs/gaussian_grid_45_by_90/ic_0016
+ ic_0017: gs://vcm-ml-raw-flexible-retention/2024-03-08-climSST-C24-SHiELD-FME-reference/regridded-zarrs/gaussian_grid_45_by_90/ic_0017
+ ic_0018: gs://vcm-ml-raw-flexible-retention/2024-03-08-climSST-C24-SHiELD-FME-reference/regridded-zarrs/gaussian_grid_45_by_90/ic_0018
+ ic_0019: gs://vcm-ml-raw-flexible-retention/2024-03-08-climSST-C24-SHiELD-FME-reference/regridded-zarrs/gaussian_grid_45_by_90/ic_0019
+ ic_0020: gs://vcm-ml-raw-flexible-retention/2024-03-08-climSST-C24-SHiELD-FME-reference/regridded-zarrs/gaussian_grid_45_by_90/ic_0020
+ ic_0021: gs://vcm-ml-raw-flexible-retention/2024-03-08-climSST-C24-SHiELD-FME-reference/regridded-zarrs/gaussian_grid_45_by_90/ic_0021
+data_output_directory: gs://vcm-ml-intermediate/2024-04-05-vertically-resolved-4deg-c24-shield-fme-ensemble-dataset
+stats:
+ output_directory: gs://vcm-ml-intermediate/2024-04-05-vertically-resolved-4deg-c24-shield-fme-ensemble-dataset-stats
+ beaker_dataset: 2024-04-05-vertically-resolved-4deg-c24-shield-fme-ensemble-dataset-stats
+ start_date: "2021-01-01"
+ end_date: "2030-12-31"
+ exclude_runs:
+ - "ic_0021"
+ data_type: FV3GFS
+dataset_computation:
+ reference_vertical_coordinate_file: gs://vcm-ml-raw-flexible-retention/2024-03-10-C96-SHiELD-FME-reference/vertical-coordinate-file/fv_core.res.nc
+ vertical_coarsening_indices:
+ - [0, 11]
+ - [11, 21]
+ - [21, 30]
+ - [30, 39]
+ - [39, 49]
+ - [49, 58]
+ - [58, 67]
+ - [67, 79]
+ renaming:
+ specific_humidity_at_two_meters: Q2m
+ air_temperature_at_two_meters: TMP2m
+ eastward_wind_at_ten_meters: UGRD10m
+ northward_wind_at_ten_meters: VGRD10m
+ variable_sources:
+ fluxes_2d.zarr:
+ - PRATEsfc
+ - LHTFLsfc
+ - SHTFLsfc
+ - DLWRFsfc
+ - DSWRFsfc
+ - DSWRFtoa
+ - ULWRFsfc
+ - ULWRFtoa
+ - USWRFsfc
+ - USWRFtoa
+ - precipitable_water_path
+ - GRAUPELsfc
+ - ICEsfc
+ - SNOWsfc
+ full_state.zarr:
+ - surface_temperature
+ - air_temperature
+ - specific_humidity
+ - cloud_water_mixing_ratio
+ - cloud_ice_mixing_ratio
+ - graupel_mixing_ratio
+ - rain_mixing_ratio
+ - snow_mixing_ratio
+ - northward_wind
+ - eastward_wind
+ - pressure_thickness_of_atmospheric_layer
+ - PRESsfc
+ - HGTsfc
+ - column_soil_moisture
+ - soil_moisture_0
+ - soil_moisture_1
+ - soil_moisture_2
+ - soil_moisture_3
+ - land_fraction
+ - ocean_fraction
+ - sea_ice_fraction
+ - specific_humidity_at_two_meters
+ - air_temperature_at_two_meters
+ - northward_wind_at_ten_meters
+ - eastward_wind_at_ten_meters
\ No newline at end of file
diff --git a/scripts/data_process/configs/shield-c96-4deg-8layer.yaml b/scripts/data_process/configs/shield-c96-4deg-8layer.yaml
new file mode 100644
index 0000000..23a5116
--- /dev/null
+++ b/scripts/data_process/configs/shield-c96-4deg-8layer.yaml
@@ -0,0 +1,68 @@
+runs:
+ ic_0001: gs://vcm-ml-raw-flexible-retention/2024-03-10-C96-SHiELD-FME-reference/regridded-zarrs/gaussian_grid_45_by_90/repeating-sst
+data_output_directory: gs://vcm-ml-intermediate/2024-04-02-vertically-resolved-4deg-c96-shield-fme-dataset
+stats:
+ output_directory: gs://vcm-ml-intermediate/2024-04-02-vertically-resolved-4deg-c96-shield-fme-dataset-stats
+ beaker_dataset: 2024-04-02-vertically-resolved-4deg-c96-shield-fme-dataset-stats
+ # start_date: "2035-01-01" # start of run
+ start_date: "2036-01-01" # we exclude just the first year so we can use it as an initial condition
+ end_date: "2060-12-31" # end of run
+ data_type: FV3GFS
+dataset_computation:
+ reference_vertical_coordinate_file: gs://vcm-ml-raw-flexible-retention/2024-03-10-C96-SHiELD-FME-reference/vertical-coordinate-file/fv_core.res.nc
+ vertical_coarsening_indices:
+ - [0, 11]
+ - [11, 21]
+ - [21, 30]
+ - [30, 39]
+ - [39, 49]
+ - [49, 58]
+ - [58, 67]
+ - [67, 79]
+ renaming:
+ specific_humidity_at_two_meters: Q2m
+ air_temperature_at_two_meters: TMP2m
+ eastward_wind_at_ten_meters: UGRD10m
+ northward_wind_at_ten_meters: VGRD10m
+ variable_sources:
+ fluxes_2d.zarr:
+ - PRATEsfc
+ - LHTFLsfc
+ - SHTFLsfc
+ - DLWRFsfc
+ - DSWRFsfc
+ - DSWRFtoa
+ - ULWRFsfc
+ - ULWRFtoa
+ - USWRFsfc
+ - USWRFtoa
+ - precipitable_water_path
+ - GRAUPELsfc
+ - ICEsfc
+ - SNOWsfc
+ full_state.zarr:
+ - surface_temperature
+ - air_temperature
+ - specific_humidity
+ - cloud_water_mixing_ratio
+ - cloud_ice_mixing_ratio
+ - graupel_mixing_ratio
+ - rain_mixing_ratio
+ - snow_mixing_ratio
+ - northward_wind
+ - eastward_wind
+ - pressure_thickness_of_atmospheric_layer
+ - PRESsfc
+ - HGTsfc
+ - column_soil_moisture
+ - soil_moisture_0
+ - soil_moisture_1
+ - soil_moisture_2
+ - soil_moisture_3
+ - land_fraction
+ - ocean_fraction
+ - sea_ice_fraction
+ - specific_humidity_at_two_meters
+ - air_temperature_at_two_meters
+ - northward_wind_at_ten_meters
+ - eastward_wind_at_ten_meters
\ No newline at end of file
diff --git a/scripts/data_process/configs/shield-som-abrupt-co2-increase-c96-1deg-8layer.yaml b/scripts/data_process/configs/shield-som-abrupt-co2-increase-c96-1deg-8layer.yaml
new file mode 100644
index 0000000..159d23c
--- /dev/null
+++ b/scripts/data_process/configs/shield-som-abrupt-co2-increase-c96-1deg-8layer.yaml
@@ -0,0 +1,94 @@
+runs:
+ abrupt-2xCO2: gs://vcm-ml-raw-flexible-retention/2024-07-03-C96-SHiELD-SOM/regridded-zarrs/gaussian_grid_180_by_360/abrupt-2xCO2
+ abrupt-3xCO2: gs://vcm-ml-raw-flexible-retention/2024-07-03-C96-SHiELD-SOM/regridded-zarrs/gaussian_grid_180_by_360/abrupt-3xCO2
+ abrupt-4xCO2: gs://vcm-ml-raw-flexible-retention/2024-07-03-C96-SHiELD-SOM/regridded-zarrs/gaussian_grid_180_by_360/abrupt-4xCO2
+data_output_directory: gs://vcm-ml-intermediate/2024-08-14-vertically-resolved-1deg-c96-shield-som-abrupt-co2-increase-fme-dataset
+stats:
+ output_directory: gs://vcm-ml-intermediate/2024-08-14-vertically-resolved-1deg-c96-shield-som-abrupt-co2-increase-fme-dataset-stats
+ beaker_dataset: 2024-07-16-vertically-resolved-1deg-fme-c96-shield-som-abrupt-co2-increase-dataset-stats
+ start_date: null
+ end_date: null
+ data_type: FV3GFS
+ exclude_runs:
+ - abrupt-2xCO2
+ - abrupt-3xCO2
+ - abrupt-4xCO2
+dataset_computation:
+ reference_vertical_coordinate_file: gs://vcm-ml-raw-flexible-retention/2024-03-10-C96-SHiELD-FME-reference/vertical-coordinate-file/fv_core.res.nc
+ vertical_coarsening_indices:
+ - [0, 11]
+ - [11, 21]
+ - [21, 30]
+ - [30, 39]
+ - [39, 49]
+ - [49, 58]
+ - [58, 67]
+ - [67, 79]
+ renaming:
+ specific_humidity_at_two_meters: Q2m
+ air_temperature_at_two_meters: TMP2m
+ eastward_wind_at_ten_meters: UGRD10m
+ northward_wind_at_ten_meters: VGRD10m
+ variable_sources:
+ fluxes_2d.zarr:
+ - PRATEsfc
+ - LHTFLsfc
+ - SHTFLsfc
+ - DLWRFsfc
+ - DSWRFsfc
+ - DSWRFtoa
+ - ULWRFsfc
+ - ULWRFtoa
+ - USWRFsfc
+ - USWRFtoa
+ - precipitable_water_path
+ - GRAUPELsfc
+ - ICEsfc
+ - SNOWsfc
+ full_state.zarr:
+ - surface_temperature
+ - air_temperature
+ - specific_humidity
+ - cloud_water_mixing_ratio
+ - cloud_ice_mixing_ratio
+ - graupel_mixing_ratio
+ - rain_mixing_ratio
+ - snow_mixing_ratio
+ - northward_wind
+ - eastward_wind
+ - pressure_thickness_of_atmospheric_layer
+ - PRESsfc
+ - HGTsfc
+ - column_soil_moisture
+ - soil_moisture_0
+ - soil_moisture_1
+ - soil_moisture_2
+ - soil_moisture_3
+ - land_fraction
+ - ocean_fraction
+ - sea_ice_fraction
+ - specific_humidity_at_two_meters
+ - air_temperature_at_two_meters
+ - northward_wind_at_ten_meters
+ - eastward_wind_at_ten_meters
+ - RH200
+ - RH500
+ - RH850
+ - TMP200
+ - TMP500
+ - TMP850
+ - UGRD200
+ - UGRD500
+ - UGRD850
+ - VGRD200
+ - VGRD500
+ - VGRD850
+ - h50
+ - h200
+ - h500
+ - h850
+ ocean_forcing.zarr:
+ - prescribed_mixed_layer_depth
+ - prescribed_qflux
+ scalar.zarr:
+ - global_mean_co2
diff --git a/scripts/data_process/configs/shield-som-abrupt-co2-increase-c96-4deg-8layer.yaml b/scripts/data_process/configs/shield-som-abrupt-co2-increase-c96-4deg-8layer.yaml
new file mode 100644
index 0000000..746f873
--- /dev/null
+++ b/scripts/data_process/configs/shield-som-abrupt-co2-increase-c96-4deg-8layer.yaml
@@ -0,0 +1,94 @@
+runs:
+ abrupt-2xCO2: gs://vcm-ml-raw-flexible-retention/2024-07-03-C96-SHiELD-SOM/regridded-zarrs/gaussian_grid_45_by_90/abrupt-2xCO2
+ abrupt-3xCO2: gs://vcm-ml-raw-flexible-retention/2024-07-03-C96-SHiELD-SOM/regridded-zarrs/gaussian_grid_45_by_90/abrupt-3xCO2
+ abrupt-4xCO2: gs://vcm-ml-raw-flexible-retention/2024-07-03-C96-SHiELD-SOM/regridded-zarrs/gaussian_grid_45_by_90/abrupt-4xCO2
+data_output_directory: gs://vcm-ml-intermediate/2024-08-14-vertically-resolved-4deg-c96-shield-som-abrupt-co2-increase-fme-dataset
+stats:
+ output_directory: gs://vcm-ml-intermediate/2024-08-14-vertically-resolved-4deg-c96-shield-som-abrupt-co2-increase-fme-dataset-stats
+ beaker_dataset: 2024-08-14-vertically-resolved-4deg-fme-c96-shield-som-abrupt-co2-increase-dataset-stats
+ start_date: null
+ end_date: null
+ data_type: FV3GFS
+ exclude_runs:
+ - abrupt-2xCO2
+ - abrupt-3xCO2
+ - abrupt-4xCO2
+dataset_computation:
+ reference_vertical_coordinate_file: gs://vcm-ml-raw-flexible-retention/2024-03-10-C96-SHiELD-FME-reference/vertical-coordinate-file/fv_core.res.nc
+ vertical_coarsening_indices:
+ - [0, 11]
+ - [11, 21]
+ - [21, 30]
+ - [30, 39]
+ - [39, 49]
+ - [49, 58]
+ - [58, 67]
+ - [67, 79]
+ renaming:
+ specific_humidity_at_two_meters: Q2m
+ air_temperature_at_two_meters: TMP2m
+ eastward_wind_at_ten_meters: UGRD10m
+ northward_wind_at_ten_meters: VGRD10m
+ variable_sources:
+ fluxes_2d.zarr:
+ - PRATEsfc
+ - LHTFLsfc
+ - SHTFLsfc
+ - DLWRFsfc
+ - DSWRFsfc
+ - DSWRFtoa
+ - ULWRFsfc
+ - ULWRFtoa
+ - USWRFsfc
+ - USWRFtoa
+ - precipitable_water_path
+ - GRAUPELsfc
+ - ICEsfc
+ - SNOWsfc
+ full_state.zarr:
+ - surface_temperature
+ - air_temperature
+ - specific_humidity
+ - cloud_water_mixing_ratio
+ - cloud_ice_mixing_ratio
+ - graupel_mixing_ratio
+ - rain_mixing_ratio
+ - snow_mixing_ratio
+ - northward_wind
+ - eastward_wind
+ - pressure_thickness_of_atmospheric_layer
+ - PRESsfc
+ - HGTsfc
+ - column_soil_moisture
+ - soil_moisture_0
+ - soil_moisture_1
+ - soil_moisture_2
+ - soil_moisture_3
+ - land_fraction
+ - ocean_fraction
+ - sea_ice_fraction
+ - specific_humidity_at_two_meters
+ - air_temperature_at_two_meters
+ - northward_wind_at_ten_meters
+ - eastward_wind_at_ten_meters
+ - RH200
+ - RH500
+ - RH850
+ - TMP200
+ - TMP500
+ - TMP850
+ - UGRD200
+ - UGRD500
+ - UGRD850
+ - VGRD200
+ - VGRD500
+ - VGRD850
+ - h50
+ - h200
+ - h500
+ - h850
+ ocean_forcing.zarr:
+ - prescribed_mixed_layer_depth
+ - prescribed_qflux
+ scalar.zarr:
+ - global_mean_co2
diff --git a/scripts/data_process/configs/shield-som-c24-4deg-8layer.yaml b/scripts/data_process/configs/shield-som-c24-4deg-8layer.yaml
new file mode 100644
index 0000000..ffbffa0
--- /dev/null
+++ b/scripts/data_process/configs/shield-som-c24-4deg-8layer.yaml
@@ -0,0 +1,140 @@
+runs:
+ 1xCO2-ic_0001: gs://vcm-ml-raw-flexible-retention/2024-07-03-C24-SHiELD-SOM/regridded-zarrs/gaussian_grid_45_by_90/1xCO2-ic_0001
+ 1xCO2-ic_0002: gs://vcm-ml-raw-flexible-retention/2024-07-03-C24-SHiELD-SOM/regridded-zarrs/gaussian_grid_45_by_90/1xCO2-ic_0002
+ 1xCO2-ic_0003: gs://vcm-ml-raw-flexible-retention/2024-07-03-C24-SHiELD-SOM/regridded-zarrs/gaussian_grid_45_by_90/1xCO2-ic_0003
+ 1xCO2-ic_0004: gs://vcm-ml-raw-flexible-retention/2024-07-03-C24-SHiELD-SOM/regridded-zarrs/gaussian_grid_45_by_90/1xCO2-ic_0004
+ 1xCO2-ic_0005: gs://vcm-ml-raw-flexible-retention/2024-07-03-C24-SHiELD-SOM/regridded-zarrs/gaussian_grid_45_by_90/1xCO2-ic_0005
+ 2xCO2-ic_0001: gs://vcm-ml-raw-flexible-retention/2024-07-03-C24-SHiELD-SOM/regridded-zarrs/gaussian_grid_45_by_90/2xCO2-ic_0001
+ 2xCO2-ic_0002: gs://vcm-ml-raw-flexible-retention/2024-07-03-C24-SHiELD-SOM/regridded-zarrs/gaussian_grid_45_by_90/2xCO2-ic_0002
+ 2xCO2-ic_0003: gs://vcm-ml-raw-flexible-retention/2024-07-03-C24-SHiELD-SOM/regridded-zarrs/gaussian_grid_45_by_90/2xCO2-ic_0003
+ 2xCO2-ic_0004: gs://vcm-ml-raw-flexible-retention/2024-07-03-C24-SHiELD-SOM/regridded-zarrs/gaussian_grid_45_by_90/2xCO2-ic_0004
+ 2xCO2-ic_0005: gs://vcm-ml-raw-flexible-retention/2024-07-03-C24-SHiELD-SOM/regridded-zarrs/gaussian_grid_45_by_90/2xCO2-ic_0005
+ 3xCO2-ic_0001: gs://vcm-ml-raw-flexible-retention/2024-07-03-C24-SHiELD-SOM/regridded-zarrs/gaussian_grid_45_by_90/3xCO2-ic_0001
+ 3xCO2-ic_0002: gs://vcm-ml-raw-flexible-retention/2024-07-03-C24-SHiELD-SOM/regridded-zarrs/gaussian_grid_45_by_90/3xCO2-ic_0002
+ 3xCO2-ic_0003: gs://vcm-ml-raw-flexible-retention/2024-07-03-C24-SHiELD-SOM/regridded-zarrs/gaussian_grid_45_by_90/3xCO2-ic_0003
+ 3xCO2-ic_0004: gs://vcm-ml-raw-flexible-retention/2024-07-03-C24-SHiELD-SOM/regridded-zarrs/gaussian_grid_45_by_90/3xCO2-ic_0004
+ 3xCO2-ic_0005: gs://vcm-ml-raw-flexible-retention/2024-07-03-C24-SHiELD-SOM/regridded-zarrs/gaussian_grid_45_by_90/3xCO2-ic_0005
+ 4xCO2-ic_0001: gs://vcm-ml-raw-flexible-retention/2024-07-03-C24-SHiELD-SOM/regridded-zarrs/gaussian_grid_45_by_90/4xCO2-ic_0001
+ 4xCO2-ic_0002: gs://vcm-ml-raw-flexible-retention/2024-07-03-C24-SHiELD-SOM/regridded-zarrs/gaussian_grid_45_by_90/4xCO2-ic_0002
+ 4xCO2-ic_0003: gs://vcm-ml-raw-flexible-retention/2024-07-03-C24-SHiELD-SOM/regridded-zarrs/gaussian_grid_45_by_90/4xCO2-ic_0003
+ 4xCO2-ic_0004: gs://vcm-ml-raw-flexible-retention/2024-07-03-C24-SHiELD-SOM/regridded-zarrs/gaussian_grid_45_by_90/4xCO2-ic_0004
+ 4xCO2-ic_0005: gs://vcm-ml-raw-flexible-retention/2024-07-03-C24-SHiELD-SOM/regridded-zarrs/gaussian_grid_45_by_90/4xCO2-ic_0005
+ abrupt-2xCO2: gs://vcm-ml-raw-flexible-retention/2024-07-03-C24-SHiELD-SOM/regridded-zarrs/gaussian_grid_45_by_90/abrupt-2xCO2
+ abrupt-3xCO2: gs://vcm-ml-raw-flexible-retention/2024-07-03-C24-SHiELD-SOM/regridded-zarrs/gaussian_grid_45_by_90/abrupt-3xCO2
+ abrupt-4xCO2: gs://vcm-ml-raw-flexible-retention/2024-07-03-C24-SHiELD-SOM/regridded-zarrs/gaussian_grid_45_by_90/abrupt-4xCO2
+ increasing-CO2: gs://vcm-ml-raw-flexible-retention/2024-07-03-C24-SHiELD-SOM/regridded-zarrs/gaussian_grid_45_by_90/increasing-CO2
+data_output_directory: gs://vcm-ml-intermediate/2024-07-17-vertically-resolved-4deg-c24-shield-som-baseline-fme-dataset
+stats:
+ output_directory: gs://vcm-ml-intermediate/2024-07-17-vertically-resolved-4deg-c24-shield-som-ensemble-fme-dataset-stats
+ beaker_dataset: 2024-07-17-vertically-resolved-4deg-fme-c24-shield-som-ensemble-dataset-stats
+ # These datasets already exclude the one-year divergence period for each
+ # ensemble member, so we can use data from all times when computing the
+ # stats.
+ start_date: null
+ end_date: null
+ data_type: FV3GFS
+ exclude_runs:
+ # Exclude all runs here since we are not interested in these for ML training
+ - 1xCO2-ic_0001
+ - 1xCO2-ic_0002
+ - 1xCO2-ic_0003
+ - 1xCO2-ic_0004
+ - 1xCO2-ic_0005
+ - 2xCO2-ic_0001
+ - 2xCO2-ic_0002
+ - 2xCO2-ic_0003
+ - 2xCO2-ic_0004
+ - 2xCO2-ic_0005
+ - 3xCO2-ic_0001
+ - 3xCO2-ic_0002
+ - 3xCO2-ic_0003
+ - 3xCO2-ic_0004
+ - 3xCO2-ic_0005
+ - 4xCO2-ic_0001
+ - 4xCO2-ic_0002
+ - 4xCO2-ic_0003
+ - 4xCO2-ic_0004
+ - 4xCO2-ic_0005
+ - abrupt-2xCO2
+ - abrupt-3xCO2
+ - abrupt-4xCO2
+ - increasing-CO2
+dataset_computation:
+ reference_vertical_coordinate_file: gs://vcm-ml-raw-flexible-retention/2024-03-10-C96-SHiELD-FME-reference/vertical-coordinate-file/fv_core.res.nc
+ vertical_coarsening_indices:
+ - [0, 11]
+ - [11, 21]
+ - [21, 30]
+ - [30, 39]
+ - [39, 49]
+ - [49, 58]
+ - [58, 67]
+ - [67, 79]
+ renaming:
+ specific_humidity_at_two_meters: Q2m
+ air_temperature_at_two_meters: TMP2m
+ eastward_wind_at_ten_meters: UGRD10m
+ northward_wind_at_ten_meters: VGRD10m
+ variable_sources:
+ fluxes_2d.zarr:
+ - PRATEsfc
+ - LHTFLsfc
+ - SHTFLsfc
+ - DLWRFsfc
+ - DSWRFsfc
+ - DSWRFtoa
+ - ULWRFsfc
+ - ULWRFtoa
+ - USWRFsfc
+ - USWRFtoa
+ - precipitable_water_path
+ - GRAUPELsfc
+ - ICEsfc
+ - SNOWsfc
+ full_state.zarr:
+ - surface_temperature
+ - air_temperature
+ - specific_humidity
+ - cloud_water_mixing_ratio
+ - cloud_ice_mixing_ratio
+ - graupel_mixing_ratio
+ - rain_mixing_ratio
+ - snow_mixing_ratio
+ - northward_wind
+ - eastward_wind
+ - pressure_thickness_of_atmospheric_layer
+ - PRESsfc
+ - HGTsfc
+ - column_soil_moisture
+ - soil_moisture_0
+ - soil_moisture_1
+ - soil_moisture_2
+ - soil_moisture_3
+ - land_fraction
+ - ocean_fraction
+ - sea_ice_fraction
+ - specific_humidity_at_two_meters
+ - air_temperature_at_two_meters
+ - northward_wind_at_ten_meters
+ - eastward_wind_at_ten_meters
+ - RH200
+ - RH500
+ - RH850
+ - TMP200
+ - TMP500
+ - TMP850
+ - UGRD200
+ - UGRD500
+ - UGRD850
+ - VGRD200
+ - VGRD500
+ - VGRD850
+ - h50
+ - h200
+ - h500
+ - h850
+ ocean_forcing.zarr:
+ - prescribed_mixed_layer_depth
+ - prescribed_qflux
+ scalar.zarr:
+ - global_mean_co2
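In these SOM configs, everything listed under `runs` is processed into the output dataset, while `stats.exclude_runs`, `start_date`, and `end_date` appear to control only which runs and times feed the normalization statistics. A hedged sketch of that selection step (names and signature are assumptions, not the stats script's actual interface):

    import xarray as xr

    def select_for_stats(run_datasets: dict[str, xr.Dataset],
                         exclude_runs: list[str],
                         start_date: str | None,
                         end_date: str | None) -> xr.Dataset:
        # Drop excluded ensemble members, restrict to the stats window, and
        # stack the remaining runs along a new "run" dimension before reducing.
        kept = [ds.sel(time=slice(start_date, end_date))
                for name, ds in run_datasets.items() if name not in exclude_runs]
        return xr.concat(kept, dim="run")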
diff --git a/scripts/data_process/configs/shield-som-c24-tuned-cdmbgwd-4deg-8layer.yaml b/scripts/data_process/configs/shield-som-c24-tuned-cdmbgwd-4deg-8layer.yaml
new file mode 100644
index 0000000..fe93ebe
--- /dev/null
+++ b/scripts/data_process/configs/shield-som-c24-tuned-cdmbgwd-4deg-8layer.yaml
@@ -0,0 +1,140 @@
+runs:
+ 1xCO2-ic_0001: gs://vcm-ml-raw-flexible-retention/2024-11-11-C24-SHiELD-SOM-tuned-cdmbgwd/regridded-zarrs/gaussian_grid_45_by_90/1xCO2-ic_0001
+ 1xCO2-ic_0002: gs://vcm-ml-raw-flexible-retention/2024-11-11-C24-SHiELD-SOM-tuned-cdmbgwd/regridded-zarrs/gaussian_grid_45_by_90/1xCO2-ic_0002
+ 1xCO2-ic_0003: gs://vcm-ml-raw-flexible-retention/2024-11-11-C24-SHiELD-SOM-tuned-cdmbgwd/regridded-zarrs/gaussian_grid_45_by_90/1xCO2-ic_0003
+ 1xCO2-ic_0004: gs://vcm-ml-raw-flexible-retention/2024-11-11-C24-SHiELD-SOM-tuned-cdmbgwd/regridded-zarrs/gaussian_grid_45_by_90/1xCO2-ic_0004
+ 1xCO2-ic_0005: gs://vcm-ml-raw-flexible-retention/2024-11-11-C24-SHiELD-SOM-tuned-cdmbgwd/regridded-zarrs/gaussian_grid_45_by_90/1xCO2-ic_0005
+ 2xCO2-ic_0001: gs://vcm-ml-raw-flexible-retention/2024-11-11-C24-SHiELD-SOM-tuned-cdmbgwd/regridded-zarrs/gaussian_grid_45_by_90/2xCO2-ic_0001
+ 2xCO2-ic_0002: gs://vcm-ml-raw-flexible-retention/2024-11-11-C24-SHiELD-SOM-tuned-cdmbgwd/regridded-zarrs/gaussian_grid_45_by_90/2xCO2-ic_0002
+ 2xCO2-ic_0003: gs://vcm-ml-raw-flexible-retention/2024-11-11-C24-SHiELD-SOM-tuned-cdmbgwd/regridded-zarrs/gaussian_grid_45_by_90/2xCO2-ic_0003
+ 2xCO2-ic_0004: gs://vcm-ml-raw-flexible-retention/2024-11-11-C24-SHiELD-SOM-tuned-cdmbgwd/regridded-zarrs/gaussian_grid_45_by_90/2xCO2-ic_0004
+ 2xCO2-ic_0005: gs://vcm-ml-raw-flexible-retention/2024-11-11-C24-SHiELD-SOM-tuned-cdmbgwd/regridded-zarrs/gaussian_grid_45_by_90/2xCO2-ic_0005
+ 3xCO2-ic_0001: gs://vcm-ml-raw-flexible-retention/2024-11-11-C24-SHiELD-SOM-tuned-cdmbgwd/regridded-zarrs/gaussian_grid_45_by_90/3xCO2-ic_0001
+ 3xCO2-ic_0002: gs://vcm-ml-raw-flexible-retention/2024-11-11-C24-SHiELD-SOM-tuned-cdmbgwd/regridded-zarrs/gaussian_grid_45_by_90/3xCO2-ic_0002
+ 3xCO2-ic_0003: gs://vcm-ml-raw-flexible-retention/2024-11-11-C24-SHiELD-SOM-tuned-cdmbgwd/regridded-zarrs/gaussian_grid_45_by_90/3xCO2-ic_0003
+ 3xCO2-ic_0004: gs://vcm-ml-raw-flexible-retention/2024-11-11-C24-SHiELD-SOM-tuned-cdmbgwd/regridded-zarrs/gaussian_grid_45_by_90/3xCO2-ic_0004
+ 3xCO2-ic_0005: gs://vcm-ml-raw-flexible-retention/2024-11-11-C24-SHiELD-SOM-tuned-cdmbgwd/regridded-zarrs/gaussian_grid_45_by_90/3xCO2-ic_0005
+ 4xCO2-ic_0001: gs://vcm-ml-raw-flexible-retention/2024-11-11-C24-SHiELD-SOM-tuned-cdmbgwd/regridded-zarrs/gaussian_grid_45_by_90/4xCO2-ic_0001
+ 4xCO2-ic_0002: gs://vcm-ml-raw-flexible-retention/2024-11-11-C24-SHiELD-SOM-tuned-cdmbgwd/regridded-zarrs/gaussian_grid_45_by_90/4xCO2-ic_0002
+ 4xCO2-ic_0003: gs://vcm-ml-raw-flexible-retention/2024-11-11-C24-SHiELD-SOM-tuned-cdmbgwd/regridded-zarrs/gaussian_grid_45_by_90/4xCO2-ic_0003
+ 4xCO2-ic_0004: gs://vcm-ml-raw-flexible-retention/2024-11-11-C24-SHiELD-SOM-tuned-cdmbgwd/regridded-zarrs/gaussian_grid_45_by_90/4xCO2-ic_0004
+ 4xCO2-ic_0005: gs://vcm-ml-raw-flexible-retention/2024-11-11-C24-SHiELD-SOM-tuned-cdmbgwd/regridded-zarrs/gaussian_grid_45_by_90/4xCO2-ic_0005
+ abrupt-2xCO2: gs://vcm-ml-raw-flexible-retention/2024-11-11-C24-SHiELD-SOM-tuned-cdmbgwd/regridded-zarrs/gaussian_grid_45_by_90/abrupt-2xCO2
+ abrupt-3xCO2: gs://vcm-ml-raw-flexible-retention/2024-11-11-C24-SHiELD-SOM-tuned-cdmbgwd/regridded-zarrs/gaussian_grid_45_by_90/abrupt-3xCO2
+ abrupt-4xCO2: gs://vcm-ml-raw-flexible-retention/2024-11-11-C24-SHiELD-SOM-tuned-cdmbgwd/regridded-zarrs/gaussian_grid_45_by_90/abrupt-4xCO2
+ increasing-CO2: gs://vcm-ml-raw-flexible-retention/2024-11-11-C24-SHiELD-SOM-tuned-cdmbgwd/regridded-zarrs/gaussian_grid_45_by_90/increasing-CO2
+data_output_directory: gs://vcm-ml-intermediate/2024-11-12-vertically-resolved-4deg-c24-shield-som-tuned-cdmbgwd-baseline-fme-dataset
+stats:
+ output_directory: gs://vcm-ml-intermediate/2024-11-12-vertically-resolved-4deg-c24-shield-som-tuned-cdmbgwd-ensemble-fme-dataset-stats
+ beaker_dataset: 2024-11-12-vertically-resolved-4deg-fme-c24-shield-som-tuned-cdmbgwd-ensemble-dataset-stats
+ # These datasets already exclude the one-year divergence period for each
+ # ensemble member, so we can use data from all times when computing the
+ # stats.
+ start_date: null
+ end_date: null
+ data_type: FV3GFS
+ exclude_runs:
+ # Exclude all runs here since we are not interested in these for ML training
+ - 1xCO2-ic_0001
+ - 1xCO2-ic_0002
+ - 1xCO2-ic_0003
+ - 1xCO2-ic_0004
+ - 1xCO2-ic_0005
+ - 2xCO2-ic_0001
+ - 2xCO2-ic_0002
+ - 2xCO2-ic_0003
+ - 2xCO2-ic_0004
+ - 2xCO2-ic_0005
+ - 3xCO2-ic_0001
+ - 3xCO2-ic_0002
+ - 3xCO2-ic_0003
+ - 3xCO2-ic_0004
+ - 3xCO2-ic_0005
+ - 4xCO2-ic_0001
+ - 4xCO2-ic_0002
+ - 4xCO2-ic_0003
+ - 4xCO2-ic_0004
+ - 4xCO2-ic_0005
+ - abrupt-2xCO2
+ - abrupt-3xCO2
+ - abrupt-4xCO2
+ - increasing-CO2
+dataset_computation:
+ reference_vertical_coordinate_file: gs://vcm-ml-raw-flexible-retention/2024-03-10-C96-SHiELD-FME-reference/vertical-coordinate-file/fv_core.res.nc
+ vertical_coarsening_indices:
+ - [0, 11]
+ - [11, 21]
+ - [21, 30]
+ - [30, 39]
+ - [39, 49]
+ - [49, 58]
+ - [58, 67]
+ - [67, 79]
+ renaming:
+ specific_humidity_at_two_meters: Q2m
+ air_temperature_at_two_meters: TMP2m
+ eastward_wind_at_ten_meters: UGRD10m
+ northward_wind_at_ten_meters: VGRD10m
+ variable_sources:
+ fluxes_2d.zarr:
+ - PRATEsfc
+ - LHTFLsfc
+ - SHTFLsfc
+ - DLWRFsfc
+ - DSWRFsfc
+ - DSWRFtoa
+ - ULWRFsfc
+ - ULWRFtoa
+ - USWRFsfc
+ - USWRFtoa
+ - precipitable_water_path
+ - GRAUPELsfc
+ - ICEsfc
+ - SNOWsfc
+ full_state.zarr:
+ - surface_temperature
+ - air_temperature
+ - specific_humidity
+ - cloud_water_mixing_ratio
+ - cloud_ice_mixing_ratio
+ - graupel_mixing_ratio
+ - rain_mixing_ratio
+ - snow_mixing_ratio
+ - northward_wind
+ - eastward_wind
+ - pressure_thickness_of_atmospheric_layer
+ - PRESsfc
+ - HGTsfc
+ - column_soil_moisture
+ - soil_moisture_0
+ - soil_moisture_1
+ - soil_moisture_2
+ - soil_moisture_3
+ - land_fraction
+ - ocean_fraction
+ - sea_ice_fraction
+ - specific_humidity_at_two_meters
+ - air_temperature_at_two_meters
+ - northward_wind_at_ten_meters
+ - eastward_wind_at_ten_meters
+ - RH200
+ - RH500
+ - RH850
+ - TMP200
+ - TMP500
+ - TMP850
+ - UGRD200
+ - UGRD500
+ - UGRD850
+ - VGRD200
+ - VGRD500
+ - VGRD850
+ - h50
+ - h200
+ - h500
+ - h850
+ ocean_forcing.zarr:
+ - prescribed_mixed_layer_depth
+ - prescribed_qflux
+ scalar.zarr:
+ - global_mean_co2
diff --git a/scripts/data_process/configs/shield-som-ensemble-c96-1deg-8layer.yaml b/scripts/data_process/configs/shield-som-ensemble-c96-1deg-8layer.yaml
new file mode 100644
index 0000000..8bb3618
--- /dev/null
+++ b/scripts/data_process/configs/shield-som-ensemble-c96-1deg-8layer.yaml
@@ -0,0 +1,121 @@
+runs:
+ 1xCO2-ic_0001: gs://vcm-ml-raw-flexible-retention/2024-07-03-C96-SHiELD-SOM/regridded-zarrs/gaussian_grid_180_by_360/1xCO2-ic_0001
+ 1xCO2-ic_0002: gs://vcm-ml-raw-flexible-retention/2024-07-03-C96-SHiELD-SOM/regridded-zarrs/gaussian_grid_180_by_360/1xCO2-ic_0002
+ 1xCO2-ic_0003: gs://vcm-ml-raw-flexible-retention/2024-07-03-C96-SHiELD-SOM/regridded-zarrs/gaussian_grid_180_by_360/1xCO2-ic_0003
+ 1xCO2-ic_0004: gs://vcm-ml-raw-flexible-retention/2024-07-03-C96-SHiELD-SOM/regridded-zarrs/gaussian_grid_180_by_360/1xCO2-ic_0004
+ 1xCO2-ic_0005: gs://vcm-ml-raw-flexible-retention/2024-07-03-C96-SHiELD-SOM/regridded-zarrs/gaussian_grid_180_by_360/1xCO2-ic_0005
+ 2xCO2-ic_0001: gs://vcm-ml-raw-flexible-retention/2024-07-03-C96-SHiELD-SOM/regridded-zarrs/gaussian_grid_180_by_360/2xCO2-ic_0001
+ 2xCO2-ic_0002: gs://vcm-ml-raw-flexible-retention/2024-07-03-C96-SHiELD-SOM/regridded-zarrs/gaussian_grid_180_by_360/2xCO2-ic_0002
+ 2xCO2-ic_0003: gs://vcm-ml-raw-flexible-retention/2024-07-03-C96-SHiELD-SOM/regridded-zarrs/gaussian_grid_180_by_360/2xCO2-ic_0003
+ 2xCO2-ic_0004: gs://vcm-ml-raw-flexible-retention/2024-07-03-C96-SHiELD-SOM/regridded-zarrs/gaussian_grid_180_by_360/2xCO2-ic_0004
+ 2xCO2-ic_0005: gs://vcm-ml-raw-flexible-retention/2024-07-03-C96-SHiELD-SOM/regridded-zarrs/gaussian_grid_180_by_360/2xCO2-ic_0005
+ 3xCO2-ic_0001: gs://vcm-ml-raw-flexible-retention/2024-07-03-C96-SHiELD-SOM/regridded-zarrs/gaussian_grid_180_by_360/3xCO2-ic_0001
+ 3xCO2-ic_0002: gs://vcm-ml-raw-flexible-retention/2024-07-03-C96-SHiELD-SOM/regridded-zarrs/gaussian_grid_180_by_360/3xCO2-ic_0002
+ 3xCO2-ic_0003: gs://vcm-ml-raw-flexible-retention/2024-07-03-C96-SHiELD-SOM/regridded-zarrs/gaussian_grid_180_by_360/3xCO2-ic_0003
+ 3xCO2-ic_0004: gs://vcm-ml-raw-flexible-retention/2024-07-03-C96-SHiELD-SOM/regridded-zarrs/gaussian_grid_180_by_360/3xCO2-ic_0004
+ 3xCO2-ic_0005: gs://vcm-ml-raw-flexible-retention/2024-07-03-C96-SHiELD-SOM/regridded-zarrs/gaussian_grid_180_by_360/3xCO2-ic_0005
+ 4xCO2-ic_0001: gs://vcm-ml-raw-flexible-retention/2024-07-03-C96-SHiELD-SOM/regridded-zarrs/gaussian_grid_180_by_360/4xCO2-ic_0001
+ 4xCO2-ic_0002: gs://vcm-ml-raw-flexible-retention/2024-07-03-C96-SHiELD-SOM/regridded-zarrs/gaussian_grid_180_by_360/4xCO2-ic_0002
+ 4xCO2-ic_0003: gs://vcm-ml-raw-flexible-retention/2024-07-03-C96-SHiELD-SOM/regridded-zarrs/gaussian_grid_180_by_360/4xCO2-ic_0003
+ 4xCO2-ic_0004: gs://vcm-ml-raw-flexible-retention/2024-07-03-C96-SHiELD-SOM/regridded-zarrs/gaussian_grid_180_by_360/4xCO2-ic_0004
+ 4xCO2-ic_0005: gs://vcm-ml-raw-flexible-retention/2024-07-03-C96-SHiELD-SOM/regridded-zarrs/gaussian_grid_180_by_360/4xCO2-ic_0005
+data_output_directory: gs://vcm-ml-intermediate/2024-07-09-vertically-resolved-1deg-c96-shield-som-ensemble-fme-dataset
+stats:
+ output_directory: gs://vcm-ml-intermediate/2024-07-09-vertically-resolved-1deg-c96-shield-som-ensemble-fme-dataset-stats
+ beaker_dataset: 2024-07-09-vertically-resolved-1deg-fme-c96-shield-som-ensemble-dataset-stats
+ # These datasets already exclude the one-year divergence period for each
+ # ensemble member, so we can use data from all times when computing the
+ # stats.
+ start_date: null
+ end_date: null
+ data_type: FV3GFS
+ exclude_runs:
+ # In-sample validation data
+ - 1xCO2-ic_0005
+ - 2xCO2-ic_0005
+ - 4xCO2-ic_0005
+ # Out-of-sample equilibrium climate data
+ - 3xCO2-ic_0001
+ - 3xCO2-ic_0002
+ - 3xCO2-ic_0003
+ - 3xCO2-ic_0004
+ - 3xCO2-ic_0005
+dataset_computation:
+ reference_vertical_coordinate_file: gs://vcm-ml-raw-flexible-retention/2024-03-10-C96-SHiELD-FME-reference/vertical-coordinate-file/fv_core.res.nc
+ vertical_coarsening_indices:
+ - [0, 11]
+ - [11, 21]
+ - [21, 30]
+ - [30, 39]
+ - [39, 49]
+ - [49, 58]
+ - [58, 67]
+ - [67, 79]
+ renaming:
+ specific_humidity_at_two_meters: Q2m
+ air_temperature_at_two_meters: TMP2m
+ eastward_wind_at_ten_meters: UGRD10m
+ northward_wind_at_ten_meters: VGRD10m
+ variable_sources:
+ fluxes_2d.zarr:
+ - PRATEsfc
+ - LHTFLsfc
+ - SHTFLsfc
+ - DLWRFsfc
+ - DSWRFsfc
+ - DSWRFtoa
+ - ULWRFsfc
+ - ULWRFtoa
+ - USWRFsfc
+ - USWRFtoa
+ - precipitable_water_path
+ - GRAUPELsfc
+ - ICEsfc
+ - SNOWsfc
+ full_state.zarr:
+ - surface_temperature
+ - air_temperature
+ - specific_humidity
+ - cloud_water_mixing_ratio
+ - cloud_ice_mixing_ratio
+ - graupel_mixing_ratio
+ - rain_mixing_ratio
+ - snow_mixing_ratio
+ - northward_wind
+ - eastward_wind
+ - pressure_thickness_of_atmospheric_layer
+ - PRESsfc
+ - HGTsfc
+ - column_soil_moisture
+ - soil_moisture_0
+ - soil_moisture_1
+ - soil_moisture_2
+ - soil_moisture_3
+ - land_fraction
+ - ocean_fraction
+ - sea_ice_fraction
+ - specific_humidity_at_two_meters
+ - air_temperature_at_two_meters
+ - northward_wind_at_ten_meters
+ - eastward_wind_at_ten_meters
+ - RH200
+ - RH500
+ - RH850
+ - TMP200
+ - TMP500
+ - TMP850
+ - UGRD200
+ - UGRD500
+ - UGRD850
+ - VGRD200
+ - VGRD500
+ - VGRD850
+ - h50
+ - h200
+ - h500
+ - h850
+ ocean_forcing.zarr:
+ - prescribed_mixed_layer_depth
+ - prescribed_qflux
+ scalar.zarr:
+ - global_mean_co2
diff --git a/scripts/data_process/configs/shield-som-ensemble-c96-4deg-8layer.yaml b/scripts/data_process/configs/shield-som-ensemble-c96-4deg-8layer.yaml
new file mode 100644
index 0000000..1df2bfe
--- /dev/null
+++ b/scripts/data_process/configs/shield-som-ensemble-c96-4deg-8layer.yaml
@@ -0,0 +1,127 @@
+runs:
+ 1xCO2-ic_0001: gs://vcm-ml-raw-flexible-retention/2024-07-03-C96-SHiELD-SOM/regridded-zarrs/gaussian_grid_45_by_90/1xCO2-ic_0001
+ 1xCO2-ic_0002: gs://vcm-ml-raw-flexible-retention/2024-07-03-C96-SHiELD-SOM/regridded-zarrs/gaussian_grid_45_by_90/1xCO2-ic_0002
+ 1xCO2-ic_0003: gs://vcm-ml-raw-flexible-retention/2024-07-03-C96-SHiELD-SOM/regridded-zarrs/gaussian_grid_45_by_90/1xCO2-ic_0003
+ 1xCO2-ic_0004: gs://vcm-ml-raw-flexible-retention/2024-07-03-C96-SHiELD-SOM/regridded-zarrs/gaussian_grid_45_by_90/1xCO2-ic_0004
+ 1xCO2-ic_0005: gs://vcm-ml-raw-flexible-retention/2024-07-03-C96-SHiELD-SOM/regridded-zarrs/gaussian_grid_45_by_90/1xCO2-ic_0005
+ 2xCO2-ic_0001: gs://vcm-ml-raw-flexible-retention/2024-07-03-C96-SHiELD-SOM/regridded-zarrs/gaussian_grid_45_by_90/2xCO2-ic_0001
+ 2xCO2-ic_0002: gs://vcm-ml-raw-flexible-retention/2024-07-03-C96-SHiELD-SOM/regridded-zarrs/gaussian_grid_45_by_90/2xCO2-ic_0002
+ 2xCO2-ic_0003: gs://vcm-ml-raw-flexible-retention/2024-07-03-C96-SHiELD-SOM/regridded-zarrs/gaussian_grid_45_by_90/2xCO2-ic_0003
+ 2xCO2-ic_0004: gs://vcm-ml-raw-flexible-retention/2024-07-03-C96-SHiELD-SOM/regridded-zarrs/gaussian_grid_45_by_90/2xCO2-ic_0004
+ 2xCO2-ic_0005: gs://vcm-ml-raw-flexible-retention/2024-07-03-C96-SHiELD-SOM/regridded-zarrs/gaussian_grid_45_by_90/2xCO2-ic_0005
+ 3xCO2-ic_0001: gs://vcm-ml-raw-flexible-retention/2024-07-03-C96-SHiELD-SOM/regridded-zarrs/gaussian_grid_45_by_90/3xCO2-ic_0001
+ 3xCO2-ic_0002: gs://vcm-ml-raw-flexible-retention/2024-07-03-C96-SHiELD-SOM/regridded-zarrs/gaussian_grid_45_by_90/3xCO2-ic_0002
+ # For expediency we did not regrid these ensemble members to 4 degree
+ # resolution. Where needed, we regrid after computing a time and ensemble mean
+ # with the one degree data.
+ # 3xCO2-ic_0003: gs://vcm-ml-raw-flexible-retention/2024-07-03-C96-SHiELD-SOM/regridded-zarrs/gaussian_grid_45_by_90/3xCO2-ic_0003
+ # 3xCO2-ic_0004: gs://vcm-ml-raw-flexible-retention/2024-07-03-C96-SHiELD-SOM/regridded-zarrs/gaussian_grid_45_by_90/3xCO2-ic_0004
+ # 3xCO2-ic_0005: gs://vcm-ml-raw-flexible-retention/2024-07-03-C96-SHiELD-SOM/regridded-zarrs/gaussian_grid_45_by_90/3xCO2-ic_0005
+ 4xCO2-ic_0001: gs://vcm-ml-raw-flexible-retention/2024-07-03-C96-SHiELD-SOM/regridded-zarrs/gaussian_grid_45_by_90/4xCO2-ic_0001
+ 4xCO2-ic_0002: gs://vcm-ml-raw-flexible-retention/2024-07-03-C96-SHiELD-SOM/regridded-zarrs/gaussian_grid_45_by_90/4xCO2-ic_0002
+ 4xCO2-ic_0003: gs://vcm-ml-raw-flexible-retention/2024-07-03-C96-SHiELD-SOM/regridded-zarrs/gaussian_grid_45_by_90/4xCO2-ic_0003
+ 4xCO2-ic_0004: gs://vcm-ml-raw-flexible-retention/2024-07-03-C96-SHiELD-SOM/regridded-zarrs/gaussian_grid_45_by_90/4xCO2-ic_0004
+ 4xCO2-ic_0005: gs://vcm-ml-raw-flexible-retention/2024-07-03-C96-SHiELD-SOM/regridded-zarrs/gaussian_grid_45_by_90/4xCO2-ic_0005
+data_output_directory: gs://vcm-ml-intermediate/2024-07-09-vertically-resolved-4deg-c96-shield-som-ensemble-fme-dataset
+stats:
+ output_directory: gs://vcm-ml-intermediate/2024-07-09-vertically-resolved-4deg-c96-shield-som-ensemble-fme-dataset-stats
+ beaker_dataset: 2024-07-09-vertically-resolved-4deg-fme-c96-shield-som-ensemble-dataset-stats
+ # These datasets already exclude the one-year divergence period for each
+ # ensemble member, so we can use data from all times when computing the
+ # stats.
+ start_date: null
+ end_date: null
+ data_type: FV3GFS
+ exclude_runs:
+ # In-sample validation data
+ - 1xCO2-ic_0005
+ - 2xCO2-ic_0005
+ - 4xCO2-ic_0005
+ # Out-of-sample equilibrium climate data
+ - 3xCO2-ic_0001
+ - 3xCO2-ic_0002
+ # For expediency we did not regrid these ensemble members to 4 degree
+ # resolution. Where needed, we regrid after computing a time and ensemble
+ # mean with the one degree data.
+ # - 3xCO2-ic_0003
+ # - 3xCO2-ic_0004
+ # - 3xCO2-ic_0005
+dataset_computation:
+ reference_vertical_coordinate_file: gs://vcm-ml-raw-flexible-retention/2024-03-10-C96-SHiELD-FME-reference/vertical-coordinate-file/fv_core.res.nc
+ vertical_coarsening_indices:
+ - [0, 11]
+ - [11, 21]
+ - [21, 30]
+ - [30, 39]
+ - [39, 49]
+ - [49, 58]
+ - [58, 67]
+ - [67, 79]
+ renaming:
+ specific_humidity_at_two_meters: Q2m
+ air_temperature_at_two_meters: TMP2m
+ eastward_wind_at_ten_meters: UGRD10m
+ northward_wind_at_ten_meters: VGRD10m
+ variable_sources:
+ fluxes_2d.zarr:
+ - PRATEsfc
+ - LHTFLsfc
+ - SHTFLsfc
+ - DLWRFsfc
+ - DSWRFsfc
+ - DSWRFtoa
+ - ULWRFsfc
+ - ULWRFtoa
+ - USWRFsfc
+ - USWRFtoa
+ - precipitable_water_path
+ - GRAUPELsfc
+ - ICEsfc
+ - SNOWsfc
+ full_state.zarr:
+ - surface_temperature
+ - air_temperature
+ - specific_humidity
+ - cloud_water_mixing_ratio
+ - cloud_ice_mixing_ratio
+ - graupel_mixing_ratio
+ - rain_mixing_ratio
+ - snow_mixing_ratio
+ - northward_wind
+ - eastward_wind
+ - pressure_thickness_of_atmospheric_layer
+ - PRESsfc
+ - HGTsfc
+ - column_soil_moisture
+ - soil_moisture_0
+ - soil_moisture_1
+ - soil_moisture_2
+ - soil_moisture_3
+ - land_fraction
+ - ocean_fraction
+ - sea_ice_fraction
+ - specific_humidity_at_two_meters
+ - air_temperature_at_two_meters
+ - northward_wind_at_ten_meters
+ - eastward_wind_at_ten_meters
+ - RH200
+ - RH500
+ - RH850
+ - TMP200
+ - TMP500
+ - TMP850
+ - UGRD200
+ - UGRD500
+ - UGRD850
+ - VGRD200
+ - VGRD500
+ - VGRD850
+ - h50
+ - h200
+ - h500
+ - h850
+ ocean_forcing.zarr:
+ - prescribed_mixed_layer_depth
+ - prescribed_qflux
+ scalar.zarr:
+ - global_mean_co2
diff --git a/scripts/data_process/configs/shield-som-increasing-co2-c96-1deg-8layer.yaml b/scripts/data_process/configs/shield-som-increasing-co2-c96-1deg-8layer.yaml
new file mode 100644
index 0000000..843f433
--- /dev/null
+++ b/scripts/data_process/configs/shield-som-increasing-co2-c96-1deg-8layer.yaml
@@ -0,0 +1,90 @@
+runs:
+ increasing-CO2: gs://vcm-ml-raw-flexible-retention/2024-07-03-C96-SHiELD-SOM/regridded-zarrs/gaussian_grid_180_by_360/increasing-CO2
+data_output_directory: gs://vcm-ml-intermediate/2024-07-16-vertically-resolved-1deg-c96-shield-som-increasing-co2-fme-dataset
+stats:
+ output_directory: gs://vcm-ml-intermediate/2024-07-16-vertically-resolved-1deg-c96-shield-som-increasing-co2-fme-dataset-stats
+ beaker_dataset: 2024-07-16-vertically-resolved-1deg-fme-c96-shield-som-increasing-co2-dataset-stats
+ start_date: null
+ end_date: null
+ data_type: FV3GFS
+ exclude_runs:
+ - increasing-CO2
+dataset_computation:
+ reference_vertical_coordinate_file: gs://vcm-ml-raw-flexible-retention/2024-03-10-C96-SHiELD-FME-reference/vertical-coordinate-file/fv_core.res.nc
+ vertical_coarsening_indices:
+ - [0, 11]
+ - [11, 21]
+ - [21, 30]
+ - [30, 39]
+ - [39, 49]
+ - [49, 58]
+ - [58, 67]
+ - [67, 79]
+ renaming:
+ specific_humidity_at_two_meters: Q2m
+ air_temperature_at_two_meters: TMP2m
+ eastward_wind_at_ten_meters: UGRD10m
+ northward_wind_at_ten_meters: VGRD10m
+ variable_sources:
+ fluxes_2d.zarr:
+ - PRATEsfc
+ - LHTFLsfc
+ - SHTFLsfc
+ - DLWRFsfc
+ - DSWRFsfc
+ - DSWRFtoa
+ - ULWRFsfc
+ - ULWRFtoa
+ - USWRFsfc
+ - USWRFtoa
+ - precipitable_water_path
+ - GRAUPELsfc
+ - ICEsfc
+ - SNOWsfc
+ full_state.zarr:
+ - surface_temperature
+ - air_temperature
+ - specific_humidity
+ - cloud_water_mixing_ratio
+ - cloud_ice_mixing_ratio
+ - graupel_mixing_ratio
+ - rain_mixing_ratio
+ - snow_mixing_ratio
+ - northward_wind
+ - eastward_wind
+ - pressure_thickness_of_atmospheric_layer
+ - PRESsfc
+ - HGTsfc
+ - column_soil_moisture
+ - soil_moisture_0
+ - soil_moisture_1
+ - soil_moisture_2
+ - soil_moisture_3
+ - land_fraction
+ - ocean_fraction
+ - sea_ice_fraction
+ - specific_humidity_at_two_meters
+ - air_temperature_at_two_meters
+ - northward_wind_at_ten_meters
+ - eastward_wind_at_ten_meters
+ - RH200
+ - RH500
+ - RH850
+ - TMP200
+ - TMP500
+ - TMP850
+ - UGRD200
+ - UGRD500
+ - UGRD850
+ - VGRD200
+ - VGRD500
+ - VGRD850
+ - h50
+ - h200
+ - h500
+ - h850
+ ocean_forcing.zarr:
+ - prescribed_mixed_layer_depth
+ - prescribed_qflux
+ scalar.zarr:
+ - global_mean_co2
diff --git a/scripts/data_process/configs/shield-som-increasing-co2-c96-4deg-8layer.yaml b/scripts/data_process/configs/shield-som-increasing-co2-c96-4deg-8layer.yaml
new file mode 100644
index 0000000..d82256e
--- /dev/null
+++ b/scripts/data_process/configs/shield-som-increasing-co2-c96-4deg-8layer.yaml
@@ -0,0 +1,90 @@
+runs:
+ increasing-CO2: gs://vcm-ml-raw-flexible-retention/2024-07-03-C96-SHiELD-SOM/regridded-zarrs/gaussian_grid_45_by_90/increasing-CO2
+data_output_directory: gs://vcm-ml-intermediate/2024-07-16-vertically-resolved-4deg-c96-shield-som-increasing-co2-fme-dataset
+stats:
+ output_directory: gs://vcm-ml-intermediate/2024-07-16-vertically-resolved-4deg-c96-shield-som-increasing-co2-fme-dataset-stats
+ beaker_dataset: 2024-07-16-vertically-resolved-4deg-fme-c96-shield-som-increasing-co2-dataset-stats
+ start_date: null
+ end_date: null
+ data_type: FV3GFS
+ exclude_runs:
+ - increasing-CO2
+dataset_computation:
+ reference_vertical_coordinate_file: gs://vcm-ml-raw-flexible-retention/2024-03-10-C96-SHiELD-FME-reference/vertical-coordinate-file/fv_core.res.nc
+ vertical_coarsening_indices:
+ - [0, 11]
+ - [11, 21]
+ - [21, 30]
+ - [30, 39]
+ - [39, 49]
+ - [49, 58]
+ - [58, 67]
+ - [67, 79]
+ renaming:
+ specific_humidity_at_two_meters: Q2m
+ air_temperature_at_two_meters: TMP2m
+ eastward_wind_at_ten_meters: UGRD10m
+ northward_wind_at_ten_meters: VGRD10m
+ variable_sources:
+ fluxes_2d.zarr:
+ - PRATEsfc
+ - LHTFLsfc
+ - SHTFLsfc
+ - DLWRFsfc
+ - DSWRFsfc
+ - DSWRFtoa
+ - ULWRFsfc
+ - ULWRFtoa
+ - USWRFsfc
+ - USWRFtoa
+ - precipitable_water_path
+ - GRAUPELsfc
+ - ICEsfc
+ - SNOWsfc
+ full_state.zarr:
+ - surface_temperature
+ - air_temperature
+ - specific_humidity
+ - cloud_water_mixing_ratio
+ - cloud_ice_mixing_ratio
+ - graupel_mixing_ratio
+ - rain_mixing_ratio
+ - snow_mixing_ratio
+ - northward_wind
+ - eastward_wind
+ - pressure_thickness_of_atmospheric_layer
+ - PRESsfc
+ - HGTsfc
+ - column_soil_moisture
+ - soil_moisture_0
+ - soil_moisture_1
+ - soil_moisture_2
+ - soil_moisture_3
+ - land_fraction
+ - ocean_fraction
+ - sea_ice_fraction
+ - specific_humidity_at_two_meters
+ - air_temperature_at_two_meters
+ - northward_wind_at_ten_meters
+ - eastward_wind_at_ten_meters
+ - RH200
+ - RH500
+ - RH850
+ - TMP200
+ - TMP500
+ - TMP850
+ - UGRD200
+ - UGRD500
+ - UGRD850
+ - VGRD200
+ - VGRD500
+ - VGRD850
+ - h50
+ - h200
+ - h500
+ - h850
+ ocean_forcing.zarr:
+ - prescribed_mixed_layer_depth
+ - prescribed_qflux
+ scalar.zarr:
+ - global_mean_co2
diff --git a/scripts/data_process/configs/shield-som-radiation-multi-call-c96-1deg-8layer.yaml b/scripts/data_process/configs/shield-som-radiation-multi-call-c96-1deg-8layer.yaml
new file mode 100644
index 0000000..19d1366
--- /dev/null
+++ b/scripts/data_process/configs/shield-som-radiation-multi-call-c96-1deg-8layer.yaml
@@ -0,0 +1,90 @@
+runs:
+ radiation-multi-call: gs://vcm-ml-raw-flexible-retention/2024-07-03-C96-SHiELD-SOM/regridded-zarrs/gaussian_grid_180_by_360/radiation-multi-call
+data_output_directory: gs://vcm-ml-intermediate/2024-10-22-vertically-resolved-1deg-c96-shield-som-radiation-multi-call-fme-dataset
+stats:
+ output_directory: gs://vcm-ml-intermediate/2024-10-22-vertically-resolved-1deg-c96-shield-som-radiation-multi-call-fme-dataset-stats
+ beaker_dataset: 2024-10-22-vertically-resolved-1deg-fme-c96-radiation-multi-call-co2-dataset-stats
+ start_date: null
+ end_date: null
+ data_type: FV3GFS
+ exclude_runs:
+ - radiation-multi-call
+dataset_computation:
+ reference_vertical_coordinate_file: gs://vcm-ml-raw-flexible-retention/2024-03-10-C96-SHiELD-FME-reference/vertical-coordinate-file/fv_core.res.nc
+ vertical_coarsening_indices:
+ - [0, 11]
+ - [11, 21]
+ - [21, 30]
+ - [30, 39]
+ - [39, 49]
+ - [49, 58]
+ - [58, 67]
+ - [67, 79]
+ renaming:
+ specific_humidity_at_two_meters: Q2m
+ air_temperature_at_two_meters: TMP2m
+ eastward_wind_at_ten_meters: UGRD10m
+ northward_wind_at_ten_meters: VGRD10m
+ variable_sources:
+ fluxes_2d.zarr:
+ - PRATEsfc
+ - LHTFLsfc
+ - SHTFLsfc
+ - DLWRFsfc
+ - DSWRFsfc
+ - DSWRFtoa
+ - ULWRFsfc
+ - ULWRFtoa
+ - USWRFsfc
+ - USWRFtoa
+ - precipitable_water_path
+ - GRAUPELsfc
+ - ICEsfc
+ - SNOWsfc
+ full_state.zarr:
+ - surface_temperature
+ - air_temperature
+ - specific_humidity
+ - cloud_water_mixing_ratio
+ - cloud_ice_mixing_ratio
+ - graupel_mixing_ratio
+ - rain_mixing_ratio
+ - snow_mixing_ratio
+ - northward_wind
+ - eastward_wind
+ - pressure_thickness_of_atmospheric_layer
+ - PRESsfc
+ - HGTsfc
+ - column_soil_moisture
+ - soil_moisture_0
+ - soil_moisture_1
+ - soil_moisture_2
+ - soil_moisture_3
+ - land_fraction
+ - ocean_fraction
+ - sea_ice_fraction
+ - specific_humidity_at_two_meters
+ - air_temperature_at_two_meters
+ - northward_wind_at_ten_meters
+ - eastward_wind_at_ten_meters
+ - RH200
+ - RH500
+ - RH850
+ - TMP200
+ - TMP500
+ - TMP850
+ - UGRD200
+ - UGRD500
+ - UGRD850
+ - VGRD200
+ - VGRD500
+ - VGRD850
+ - h50
+ - h200
+ - h500
+ - h850
+ ocean_forcing.zarr:
+ - prescribed_mixed_layer_depth
+ - prescribed_qflux
+ scalar.zarr:
+ - global_mean_co2
diff --git a/scripts/data_process/configs/shield-som-spin-up-c96-1deg-8layer.yaml b/scripts/data_process/configs/shield-som-spin-up-c96-1deg-8layer.yaml
new file mode 100644
index 0000000..b6b5c64
--- /dev/null
+++ b/scripts/data_process/configs/shield-som-spin-up-c96-1deg-8layer.yaml
@@ -0,0 +1,97 @@
+runs:
+ 1xCO2-spin-up-ic_0005: gs://vcm-ml-raw-flexible-retention/2024-07-03-C96-SHiELD-SOM/regridded-zarrs/gaussian_grid_180_by_360/1xCO2-spin-up-ic_0005
+ 2xCO2-spin-up-ic_0005: gs://vcm-ml-raw-flexible-retention/2024-07-03-C96-SHiELD-SOM/regridded-zarrs/gaussian_grid_180_by_360/2xCO2-spin-up-ic_0005
+ 3xCO2-spin-up-ic_0002: gs://vcm-ml-raw-flexible-retention/2024-07-03-C96-SHiELD-SOM/regridded-zarrs/gaussian_grid_180_by_360/3xCO2-spin-up-ic_0002
+ 4xCO2-spin-up-ic_0005: gs://vcm-ml-raw-flexible-retention/2024-07-03-C96-SHiELD-SOM/regridded-zarrs/gaussian_grid_180_by_360/4xCO2-spin-up-ic_0005
+data_output_directory: gs://vcm-ml-intermediate/2024-08-15-vertically-resolved-1deg-c96-shield-som-ensemble-spin-up-fme-dataset
+stats:
+ output_directory: gs://vcm-ml-intermediate/2024-08-15-vertically-resolved-1deg-c96-shield-som-ensemble-spin-up-fme-dataset-stats
+ beaker_dataset: 2024-08-15-vertically-resolved-1deg-fme-c96-shield-som-ensemble-spin-up-dataset-stats
+ start_date: null
+ end_date: null
+ data_type: FV3GFS
+ exclude_runs:
+ # In sample validation data
+ - 1xCO2-spin-up-ic_0005
+ - 2xCO2-spin-up-ic_0005
+ - 3xCO2-spin-up-ic_0002
+ - 4xCO2-spin-up-ic_0005
+dataset_computation:
+ reference_vertical_coordinate_file: gs://vcm-ml-raw-flexible-retention/2024-03-10-C96-SHiELD-FME-reference/vertical-coordinate-file/fv_core.res.nc
+ vertical_coarsening_indices:
+ - [0, 11]
+ - [11, 21]
+ - [21, 30]
+ - [30, 39]
+ - [39, 49]
+ - [49, 58]
+ - [58, 67]
+ - [67, 79]
+ renaming:
+ specific_humidity_at_two_meters: Q2m
+ air_temperature_at_two_meters: TMP2m
+ eastward_wind_at_ten_meters: UGRD10m
+ northward_wind_at_ten_meters: VGRD10m
+ variable_sources:
+ fluxes_2d.zarr:
+ - PRATEsfc
+ - LHTFLsfc
+ - SHTFLsfc
+ - DLWRFsfc
+ - DSWRFsfc
+ - DSWRFtoa
+ - ULWRFsfc
+ - ULWRFtoa
+ - USWRFsfc
+ - USWRFtoa
+ - precipitable_water_path
+ - GRAUPELsfc
+ - ICEsfc
+ - SNOWsfc
+ full_state.zarr:
+ - surface_temperature
+ - air_temperature
+ - specific_humidity
+ - cloud_water_mixing_ratio
+ - cloud_ice_mixing_ratio
+ - graupel_mixing_ratio
+ - rain_mixing_ratio
+ - snow_mixing_ratio
+ - northward_wind
+ - eastward_wind
+ - pressure_thickness_of_atmospheric_layer
+ - PRESsfc
+ - HGTsfc
+ - column_soil_moisture
+ - soil_moisture_0
+ - soil_moisture_1
+ - soil_moisture_2
+ - soil_moisture_3
+ - land_fraction
+ - ocean_fraction
+ - sea_ice_fraction
+ - specific_humidity_at_two_meters
+ - air_temperature_at_two_meters
+ - northward_wind_at_ten_meters
+ - eastward_wind_at_ten_meters
+ - RH200
+ - RH500
+ - RH850
+ - TMP200
+ - TMP500
+ - TMP850
+ - UGRD200
+ - UGRD500
+ - UGRD850
+ - VGRD200
+ - VGRD500
+ - VGRD850
+ - h50
+ - h200
+ - h500
+ - h850
+ ocean_forcing.zarr:
+ - prescribed_mixed_layer_depth
+ - prescribed_qflux
+ scalar.zarr:
+ - global_mean_co2
diff --git a/scripts/data_process/convert_to_monthly_netcdf_fv3gfs.sh b/scripts/data_process/convert_to_monthly_netcdf_fv3gfs.sh
new file mode 100755
index 0000000..638901a
--- /dev/null
+++ b/scripts/data_process/convert_to_monthly_netcdf_fv3gfs.sh
@@ -0,0 +1,56 @@
+#!/bin/bash
+
+# This script launches N_IC jobs to convert all GCS zarr data to local monthly netCDFs
+
+while [[ "$#" -gt 0 ]]
+do case $1 in
+ --input-url) BASE_INPUT_URL="$2"
+ shift;;
+ --n-ic) N_IC=$2
+ shift;;
+ --output-dir) BASE_OUTPUT_DIR="$2"
+ shift;;
+ --start-date) START_DATE="$2"
+ shift;;
+ --end-date) END_DATE="$2"
+ shift;;
+ *) echo "Unknown parameter passed: $1"
+ exit 1;;
+esac
+shift
+done
+
+if [[ -z "${BASE_INPUT_URL}" ]]
+then
+ echo "Option --input-url missing"
+ exit 1;
+elif [[ -z "${N_IC}" ]]
+then
+ echo "Option --n-ic missing"
+ exit 1;
+elif [[ -z "${BASE_OUTPUT_DIR}" ]]
+then
+ echo "Option --output-dir missing"
+ exit 1;
+elif [[ -z "${START_DATE}" ]]
+then
+ echo "Option --start-date missing"
+ exit 1;
+elif [[ -z "${END_DATE}" ]]
+then
+ echo "Option --end-date missing"
+ exit 1;
+fi
+
+
+
+for IC in $(seq 1 $(( N_IC ))); do
+ IC_STR=$(printf "%04d" ${IC})
+ INPUT_URL=${BASE_INPUT_URL}/ic_${IC_STR}.zarr
+ OUTPUT_DIR=${BASE_OUTPUT_DIR}/ic_${IC_STR}
+ python convert_to_monthly_netcdf.py \
+ $INPUT_URL \
+ $OUTPUT_DIR \
+ --start-date $START_DATE \
+ --end-date $END_DATE &
+done
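The script above expects one zarr store per ensemble member, named ic_0001.zarr, ic_0002.zarr, ..., under the base input URL, and launches one background conversion job per member. A hypothetical invocation, with placeholder bucket, directory, and date values:

    cd scripts/data_process
    ./convert_to_monthly_netcdf_fv3gfs.sh \
        --input-url gs://my-bucket/example-run \
        --n-ic 4 \
        --output-dir /scratch/example-run-monthly \
        --start-date 2020-01-01 \
        --end-date 2020-12-31

Note that the per-member jobs are backgrounded with &, so the script returns before the conversions finish.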
diff --git a/scripts/data_process/earth2grid.Dockerfile b/scripts/data_process/earth2grid.Dockerfile
new file mode 100644
index 0000000..e8fd37c
--- /dev/null
+++ b/scripts/data_process/earth2grid.Dockerfile
@@ -0,0 +1,11 @@
+FROM nvcr.io/nvidia/pytorch:23.08-py3
+
+# Clone and install earth2grid if not already installed
+RUN PACKAGE=earth2grid && \
+    if ! pip show "$PACKAGE" > /dev/null 2>&1; then \
+ git clone https://github.com/NVlabs/earth2grid.git && \
+ cd earth2grid && \
+ pip install --no-build-isolation . && \
+ cd .. && \
+ rm -rf earth2grid; \
+ fi
\ No newline at end of file
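To use the image defined above, a plausible build-and-smoke-test sequence, run from scripts/data_process (the tag name and the import check are illustrative, not part of this change):

    docker build -f earth2grid.Dockerfile -t earth2grid-image:latest .
    docker run --rm earth2grid-image:latest python -c "import earth2grid"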
diff --git a/scripts/data_process/generate_beaker_stats_dataset_fv3gfs.sh b/scripts/data_process/generate_beaker_stats_dataset_fv3gfs.sh
new file mode 100755
index 0000000..a440b5d
--- /dev/null
+++ b/scripts/data_process/generate_beaker_stats_dataset_fv3gfs.sh
@@ -0,0 +1,54 @@
+#!/bin/bash
+
+set -e
+
+while [[ "$#" -gt 0 ]]
+do case $1 in
+ --input-url) INPUT_URL="$2"
+ shift;;
+ --start-date) START_DATE="$2"
+ shift;;
+ --end-date) END_DATE="$2"
+ shift;;
+ --name) DATASET_NAME="$2"
+ shift;;
+ --desc) DATASET_DESC="$2"
+ shift;;
+ --script-flags) SCRIPT_FLAGS="$2"
+ shift;;
+ *) echo "Unknown parameter passed: $1"
+ exit 1;;
+esac
+shift
+done
+
+if [[ -z "${INPUT_URL}" ]]
+then
+ echo "Option --input-url missing"
+ exit 1;
+elif [[ -z "${START_DATE}" ]]
+then
+ echo "Option --start-date missing"
+ exit 1;
+elif [[ -z "${END_DATE}" ]]
+then
+ echo "Option --end-date missing"
+ exit 1;
+elif [[ -z "${DATASET_NAME}" ]]
+then
+ echo "Option --dataset-name missing"
+ exit 1;
+fi
+
+OUTPUT_DIR="/tmp/$(uuidgen)"
+
+python get_stats.py \
+ $INPUT_URL \
+ ${OUTPUT_DIR} \
+ --start-date $START_DATE \
+ --end-date $END_DATE ${SCRIPT_FLAGS}
+
+beaker dataset create ${OUTPUT_DIR} \
+ --name ${DATASET_NAME} --desc "${DATASET_DESC}"
+
+rm -rf ${OUTPUT_DIR}
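A hypothetical invocation of the wrapper above, assuming get_stats.py accepts an input zarr URL and an output directory positionally as the wrapper passes them; the URL, dates, name, and description are placeholders, and --script-flags is an optional string of extra arguments forwarded to get_stats.py:

    cd scripts/data_process
    ./generate_beaker_stats_dataset_fv3gfs.sh \
        --input-url gs://my-bucket/example-dataset.zarr \
        --start-date 2020-01-01 \
        --end-date 2020-12-31 \
        --name example-fme-dataset-stats \
        --desc "Normalization stats for an example FME dataset"

Running it also requires the beaker CLI to be installed and authenticated, since the results are uploaded with 'beaker dataset create'.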
diff --git a/scripts/data_process/generate_datasets_e3smv2.sh b/scripts/data_process/generate_datasets_e3smv2.sh
new file mode 100755
index 0000000..056204f
--- /dev/null
+++ b/scripts/data_process/generate_datasets_e3smv2.sh
@@ -0,0 +1,89 @@
+#!/bin/bash -l
+
+#SBATCH -A m4331
+#SBATCH -q regular
+#SBATCH -C cpu
+#SBATCH --nodes=1
+#SBATCH --ntasks-per-node=1
+#SBATCH -t 01:00:00
+#SBATCH --output=joblogs/%x-%j.out
+
+while [[ "$#" -gt 0 ]]
+do case $1 in
+ -i|--input-dir) INPUT_DIR="$2"
+ shift;;
+ -c|--config) CONFIG="$2"
+ shift;;
+ -z|--zarr) ZARR="$2"
+ shift;;
+ -o|--output-dir) OUTPUT_DIR="$2"
+ shift;;
+ *) echo "Unknown parameter passed: $1"
+ exit 1;;
+esac
+shift
+done
+
+if [[ -z "${INPUT_DIR}" ]]
+then
+ echo "Option -i, --input-dir missing"
+ exit 1;
+elif [[ -z "${CONFIG}" ]]
+then
+ echo "Option -c, --config missing"
+ exit 1;
+elif [[ -z "${ZARR}" ]]
+then
+ echo "Option -z, --zarr missing"
+ exit 1;
+elif [[ -z "${OUTPUT_DIR}" ]]
+then
+ echo "Option -o, --output-dir missing"
+ exit 1;
+fi
+
+# output dir should be somewhere on $SCRATCH, even if the intention is to send
+# the data to CFS or some external data store
+mkdir -p $OUTPUT_DIR
+
+# stripe_small is recommended for files of size 1-10 GB on Perlmutter's Lustre
+# scratch filesystem and stripes across 8 OSTs
+# see https://docs.nersc.gov/performance/io/lustre/#nersc-file-striping-recommendations
+stripe_small $OUTPUT_DIR
+
+# NOTE: assumes you've already created the fv3net conda env. See
+# https://github.com/ai2cm/fv3net/blob/8ed295cf0b8ca49e24ae5d6dd00f57e8b30169ac/Makefile#L310
+source activate fv3net
+
+set -xe
+
+# create the zarr from E3SMv2 .nc files
+python -u compute_dataset_e3smv2.py --n-workers=16 --config=${CONFIG} \
+ -i ${INPUT_DIR} -o ${ZARR}
+
+# First year of data (intended for training)
+python -u convert_to_monthly_netcdf.py \
+ ${ZARR} \
+ ${OUTPUT_DIR}/traindata \
+ --start-date 1970-01-01 \
+ --end-date 1970-12-31 \
+ --nc-format NETCDF4
+
+# Validation on the next 5 months
+python -u convert_to_monthly_netcdf.py \
+ ${ZARR} \
+ ${OUTPUT_DIR}/validdata \
+ --start-date 1971-01-01 \
+ --end-date 1971-05-31 \
+ --nc-format NETCDF4
+
+# Final 7 months for predictiondata reference
+python -u convert_to_monthly_netcdf.py \
+ ${ZARR} \
+ ${OUTPUT_DIR}/predictiondata \
+ --start-date 1971-06-01 \
+ --end-date 1971-12-31 \
+ --nc-format NETCDF4
+
+# compute all stats on training data
+python -u get_stats.py ${CONFIG} 0
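Given the SBATCH directives at the top, the script above is presumably submitted through Slurm on Perlmutter; a hypothetical submission with placeholder paths and a placeholder config filename (any config consumable by compute_dataset_e3smv2.py and get_stats.py would be passed the same way):

    cd scripts/data_process
    mkdir -p joblogs  # the directory in the #SBATCH --output path is not created automatically
    sbatch generate_datasets_e3smv2.sh \
        --input-dir $SCRATCH/e3smv2-raw-netcdf \
        --config configs/e3smv2-example.yaml \
        --zarr $SCRATCH/e3smv2-example.zarr \
        --output-dir $SCRATCH/e3smv2-example-dataset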
diff --git a/scripts/data_process/get_stats.py b/scripts/data_process/get_stats.py
index 2a374df..448e787 100644
--- a/scripts/data_process/get_stats.py
+++ b/scripts/data_process/get_stats.py
@@ -38,12 +38,13 @@
"FV3GFS": ["time", "grid_xt", "grid_yt"],
"E3SMV2": ["time", "lat", "lon"],
"ERA5": ["time", "latitude", "longitude"],
+ "CM4": ["time", "lat", "lon"],
}
def add_history_attrs(ds, input_zarr, start_date, end_date, n_samples):
ds.attrs["history"] = (
- "Created by ace/fv3gfs_data_process/get_stats.py. INPUT_ZARR:"
+ "Created by full-model/fv3gfs_data_process/get_stats.py. INPUT_ZARR:"
f" {input_zarr}, START_DATE: {start_date}, END_DATE: {end_date}."
)
ds.attrs["input_samples"] = n_samples
@@ -64,10 +65,11 @@ def copy(source: str, destination: str):
@dataclasses.dataclass
class StatsConfig:
output_directory: str
- data_type: Literal["FV3GFS", "E3SMV2", "ERA5"]
+ data_type: Literal["FV3GFS", "E3SMV2", "ERA5", "CM4"]
exclude_runs: List[str] = dataclasses.field(default_factory=list)
start_date: Optional[str] = None
end_date: Optional[str] = None
+ beaker_dataset: Optional[str] = None
@dataclasses.dataclass
diff --git a/scripts/data_process/test_config.py b/scripts/data_process/test_config.py
new file mode 100644
index 0000000..aa24d30
--- /dev/null
+++ b/scripts/data_process/test_config.py
@@ -0,0 +1,27 @@
+import os
+
+import dacite
+import pytest
+import yaml
+from combine_stats import Config as CombineStatsConfig
+from get_stats import Config as GetStatsConfig
+from upload_stats import Config as UploadStatsConfig
+
+DIRNAME = os.path.abspath(os.path.dirname(__file__))
+# list all config YAMLs in DIRNAME/configs
+CONFIG_YAMLS = [
+    os.path.join(DIRNAME, "configs", f)
+    for f in os.listdir(os.path.join(DIRNAME, "configs"))
+    if f.endswith(".yaml")
+]
+
+
+@pytest.mark.parametrize(
+ "filename",
+ CONFIG_YAMLS,
+)
+@pytest.mark.parametrize("cls", [GetStatsConfig, UploadStatsConfig, CombineStatsConfig])
+def test_config_valid(filename, cls):
+ with open(filename, "r") as f:
+ config_data = yaml.load(f, Loader=yaml.CLoader)
+ dacite.from_dict(data_class=cls, data=config_data)
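The test above only checks that every YAML under configs/ deserializes cleanly into each of the three config dataclasses via dacite; a typical local run might be:

    cd scripts/data_process
    python -m pytest test_config.py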
diff --git a/scripts/data_process/upload_stats.py b/scripts/data_process/upload_stats.py
new file mode 100644
index 0000000..863ac91
--- /dev/null
+++ b/scripts/data_process/upload_stats.py
@@ -0,0 +1,89 @@
+import dataclasses
+import shutil
+import tempfile
+from typing import Dict, List, Optional
+
+import click
+import dacite
+import fsspec
+import yaml
+
+
+def copy(source: str, destination: str):
+ """Copy between any two 'filesystems'. Do not use for large files.
+
+ Args:
+ source: Path to source file/object.
+ destination: Path to destination.
+ """
+ with fsspec.open(source) as f_source:
+ with fsspec.open(destination, "wb") as f_destination:
+ shutil.copyfileobj(f_source, f_destination)
+
+
+@dataclasses.dataclass
+class StatsConfig:
+ output_directory: str
+ beaker_dataset: str
+ exclude_runs: List[str] = dataclasses.field(default_factory=list)
+ start_date: Optional[str] = None
+ end_date: Optional[str] = None
+
+
+@dataclasses.dataclass
+class Config:
+ runs: Dict[str, str]
+ data_output_directory: str
+ stats: StatsConfig
+
+
+@click.command()
+@click.argument("config_yaml", type=str)
+def main(config_yaml: str):
+ """
+    Upload combined statistics for the data processing pipeline to a Beaker dataset.
+
+ Arguments:
+ config_yaml -- Path to the configuration file for the data processing pipeline.
+ """
+ # imported here so we don't need to install beaker for the tests
+ from beaker import Beaker
+
+ with open(config_yaml, "r") as f:
+ config_data = yaml.load(f, Loader=yaml.CLoader)
+ config = dacite.from_dict(data_class=Config, data=config_data)
+
+ stats_combined_dir = config.stats.output_directory + "/combined/"
+ beaker = Beaker.from_env()
+ with tempfile.TemporaryDirectory() as tmpdir:
+ for filename in (
+ "centering.nc",
+ "scaling-full-field.nc",
+ "scaling-residual.nc",
+ "time-mean.nc",
+ ):
+ copy(stats_combined_dir + filename, tmpdir + "/" + filename)
+ runs = [run for run in config.runs if run not in config.stats.exclude_runs]
+ run_names = ", ".join(runs)
+ if config.stats.start_date is None:
+ start = "start of run"
+ else:
+ start = config.stats.start_date
+ if config.stats.end_date is None:
+ end = "end of run"
+ else:
+ end = config.stats.end_date
+ beaker.dataset.create(
+ config.stats.beaker_dataset,
+ tmpdir,
+ workspace="ai2/ace",
+ description=(
+ "Coefficients for normalization for data "
+ f"{config.data_output_directory} runs {run_names}. "
+ f"Computed from {start} to {end}."
+ ),
+ )
+
+
+if __name__ == "__main__":
+ main()
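A hypothetical command line for the uploader above, using one of the configs added in this change; it assumes the combined stats files (centering.nc, scaling-full-field.nc, scaling-residual.nc, time-mean.nc) already exist under the configured stats output_directory/combined/ path, and that Beaker credentials are available in the environment for Beaker.from_env():

    cd scripts/data_process
    python upload_stats.py configs/shield-som-increasing-co2-c96-1deg-8layer.yaml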
diff --git a/scripts/era5/netcdf_to_zarr/netcdf_to_zarr_pipeline.py b/scripts/era5/netcdf_to_zarr/netcdf_to_zarr_pipeline.py
index 0826b98..33fa055 100644
--- a/scripts/era5/netcdf_to_zarr/netcdf_to_zarr_pipeline.py
+++ b/scripts/era5/netcdf_to_zarr/netcdf_to_zarr_pipeline.py
@@ -223,7 +223,7 @@ def create_record(
def create_record_local(
- item: Tuple[pd.Timestamp, pd.Timestamp, str, str, str]
+ item: Tuple[pd.Timestamp, pd.Timestamp, str, str, str],
) -> Generator[Tuple[xbeam.Key, xr.Dataset], None, None]:
start_time, end_time, variable, category, path = item
with tempfile.TemporaryDirectory() as tmpdir:
diff --git a/scripts/era5/pipeline/xr-beam-pipeline.py b/scripts/era5/pipeline/xr-beam-pipeline.py
index bfdd35f..41fbf12 100644
--- a/scripts/era5/pipeline/xr-beam-pipeline.py
+++ b/scripts/era5/pipeline/xr-beam-pipeline.py
@@ -2,6 +2,7 @@
import datetime
import logging
import os
+from typing import Sequence
import apache_beam as beam
import metview
@@ -61,8 +62,12 @@ def grid_attribute_fix(ds, names_to_fix, reference_name):
URL_GOOGLE_LATLON = (
"gs://gcp-public-data-arco-era5/ar/full_37-1h-0p25deg-chunk-1.zarr-v3"
)
+# the following dataset was manually generated in https://github.com/ai2cm/explore/blob/master/oliwm/2024-07-31-generate-ERA5-co2.ipynb # noqa: E501
+URL_CO2 = "gs://vcm-ml-raw-flexible-retention/2024-11-11-co2-annual-mean-for-era5.zarr"
-URL_NCAR_ERA5 = "gs://vcm-ml-intermediate/2024-05-17-era5-025deg-2D-variables-from-NCAR-as-zarr" # noqa: E501
+URL_NCAR_ERA5 = (
+ "gs://vcm-ml-intermediate/2024-05-17-era5-025deg-2D-variables-from-NCAR-as-zarr" # noqa: E501
+)
URL_INVARIANT = f"{URL_NCAR_ERA5}/e5.oper.invariant.zarr"
URL_SURFACE_ANALYSIS_LATLON = f"{URL_NCAR_ERA5}/e5.oper.an.sfc.zarr"
URL_MEAN_FLUX = f"{URL_NCAR_ERA5}/e5.oper.fc.sfc.meanflux.zarr"
@@ -198,13 +203,6 @@ def grid_attribute_fix(ds, names_to_fix, reference_name):
RENAME_Z_PRES = {
f"geopotential_{p}": f"h{p}" for p in OUTPUT_PRESSURE_LEVELS_GEOPOTENTIAL
}
-RENAME_ETC = {
- "skt": "surface_temperature",
- "t2m": "TMP2m",
- "u10": "UGRD10m",
- "v10": "VGRD10m",
- "d2m": "DPT2m",
-}
RENAME_PRESSURE_LEVEL = {
**RENAME_Q_PRES,
**RENAME_T_PRES,
@@ -214,28 +212,26 @@ def grid_attribute_fix(ds, names_to_fix, reference_name):
}
-def set_nlayer_globals(output_layer_indices):
- # set global vars that depend on the output layer indices
- global OUTPUT_LAYER_INDICES
- OUTPUT_LAYER_INDICES = output_layer_indices
- assert OUTPUT_LAYER_INDICES[-1] == N_INPUT_LAYERS
-
- global N_OUTPUT_LAYERS
- N_OUTPUT_LAYERS = len(OUTPUT_LAYER_INDICES) - 1
-
- RENAME_Q = {f"q_{i}": f"specific_total_water_{i}" for i in range(N_OUTPUT_LAYERS)}
- RENAME_T = {f"t_{i}": f"air_temperature_{i}" for i in range(N_OUTPUT_LAYERS)}
- RENAME_U = {f"u_{i}": f"eastward_wind_{i}" for i in range(N_OUTPUT_LAYERS)}
- RENAME_V = {f"v_{i}": f"northward_wind_{i}" for i in range(N_OUTPUT_LAYERS)}
-
- global RENAME_NATIVE
- RENAME_NATIVE = {
- **RENAME_Q,
- **RENAME_T,
- **RENAME_U,
- **RENAME_V,
- **RENAME_ETC,
+def _get_native_rename_dict(n_output_layers):
+ rename_q = {f"q_{i}": f"specific_total_water_{i}" for i in range(n_output_layers)}
+ rename_t = {f"t_{i}": f"air_temperature_{i}" for i in range(n_output_layers)}
+ rename_u = {f"u_{i}": f"eastward_wind_{i}" for i in range(n_output_layers)}
+ rename_v = {f"v_{i}": f"northward_wind_{i}" for i in range(n_output_layers)}
+ rename_etc = {
+ "skt": "surface_temperature",
+ "t2m": "TMP2m",
+ "u10": "UGRD10m",
+ "v10": "VGRD10m",
+ "d2m": "DPT2m",
+ }
+ rename_native = {
+ **rename_q,
+ **rename_t,
+ **rename_u,
+ **rename_v,
+ **rename_etc,
}
+ return rename_native
def _open_zarr(key, sel_indices):
@@ -290,6 +286,17 @@ def open_quarter_degree_datasets(indices) -> xr.Dataset:
return sfc, invariant
+def open_co2_dataset(start_time, end_time) -> xr.Dataset:
+ co2 = xr.open_zarr(URL_CO2, chunks=None)
+ ds_start = pd.Timestamp(co2.time.values[0])
+ ds_stop = pd.Timestamp(co2.time.values[-1])
+    assert start_time >= ds_start, "CO2 dataset time start out of bounds"
+    assert end_time <= ds_stop, "CO2 dataset time stop out of bounds"
+ co2 = co2.sel(time=slice(start_time, end_time))
+ co2 = co2.load()
+ return co2
+
+
def _to_dataset(fs: metview.Fieldset) -> xr.Dataset:
return fs.to_dataset().load()
@@ -375,7 +382,12 @@ def process_pressure_level_data(key, ds, output_grid=DEFAULT_OUTPUT_GRID):
return new_key, output
-def _process_native_data(ds: xr.Dataset, output_grid: str) -> xr.Dataset:
+def _process_native_data(
+ ds: xr.Dataset, output_grid: str, output_layer_indices: Sequence[int]
+) -> xr.Dataset:
+ n_output_layers = len(output_layer_indices) - 1
+ rename_dict = _get_native_rename_dict(n_output_layers)
+
xr.set_options(keep_attrs=True)
# singleton time dimension interferes with metview
ds = ds.squeeze()
@@ -405,15 +417,15 @@ def _process_native_data(ds: xr.Dataset, output_grid: str) -> xr.Dataset:
thicknesses = _to_dataarray(thicknesses_fs, "pres")
for short_name in ["q", "t", "u", "v"]:
variable = _to_dataarray(fieldset_gg.select(shortName=short_name), short_name)
- for output_index in range(N_OUTPUT_LAYERS): # type: ignore[name-defined]
+ for output_index in range(n_output_layers):
logging.info(
f"Computing vertical integral of {short_name} "
f"for output layer {output_index}."
)
fine_levels = slice(
- OUTPUT_LAYER_INDICES[output_index], # type: ignore[name-defined]
- OUTPUT_LAYER_INDICES[output_index + 1], # type: ignore[name-defined]
+ output_layer_indices[output_index],
+ output_layer_indices[output_index + 1],
)
coarse_level_thicknesses = thicknesses.isel(hybrid=fine_levels)
total_thickness = coarse_level_thicknesses.sum("hybrid")
@@ -455,7 +467,7 @@ def _process_native_data(ds: xr.Dataset, output_grid: str) -> xr.Dataset:
output = _adjust_latlon(output)
- output = output.rename(RENAME_NATIVE) # type: ignore[name-defined]
+ output = output.rename(rename_dict)
for name, attrs in DESIRED_ATTRS.items():
if name in output:
output[name] = output[name].assign_attrs(**attrs)
@@ -481,8 +493,13 @@ def _process_native_data(ds: xr.Dataset, output_grid: str) -> xr.Dataset:
return output
-def process_native_data(key, ds, output_grid=DEFAULT_OUTPUT_GRID):
- output = _process_native_data(ds, output_grid)
+def process_native_data(
+ key,
+ ds,
+ output_grid=DEFAULT_OUTPUT_GRID,
+ output_layer_indices=DEFAULT_OUTPUT_LAYER_INDICES,
+):
+ output = _process_native_data(ds, output_grid, output_layer_indices)
new_key = key.replace(
offsets={"time": key.offsets["time"], "latitude": 0, "longitude": 0},
vars=frozenset(output.keys()),
@@ -658,7 +675,9 @@ def process_quarter_degree_data_sfc_an(
return new_key, output_ds
-def _get_vertical_coordinate(ds: xr.Dataset, name: str) -> xr.Dataset:
+def _get_vertical_coordinate(
+ ds: xr.Dataset, name: str, output_layer_indices: Sequence[int]
+) -> xr.Dataset:
"""Get the ak/bk vertical coordinate on coarse layer interfaces.
Assuming that ds[name] is a 3D variable which includes
@@ -667,8 +686,8 @@ def _get_vertical_coordinate(ds: xr.Dataset, name: str) -> xr.Dataset:
hybrid_sigma_values = ds[name].attrs["GRIB_pv"]
ak = hybrid_sigma_values[: N_INPUT_LAYERS + 1]
bk = hybrid_sigma_values[N_INPUT_LAYERS + 1 :]
- ak_coarse = [ak[i] for i in OUTPUT_LAYER_INDICES] # type: ignore[name-defined]
- bk_coarse = [bk[i] for i in OUTPUT_LAYER_INDICES] # type: ignore[name-defined]
+ ak_coarse = [ak[i] for i in output_layer_indices]
+ bk_coarse = [bk[i] for i in output_layer_indices]
ak_coarse_ds = xr.Dataset({f"ak_{i}": value for i, value in enumerate(ak_coarse)})
bk_coarse_ds = xr.Dataset({f"bk_{i}": value for i, value in enumerate(bk_coarse)})
for name in ak_coarse_ds:
@@ -684,10 +703,12 @@ def _make_template(
ds_quarter_degree_sfc,
ds_quarter_degree_invariant,
ds_google_latlon,
+ ds_co2,
ds_akbk,
output_chunks,
reuse_template,
output_grid,
+ output_layer_indices,
):
"""Here we (mostly) lazily process the data to make a reference zarr store
for the output. This function mirrors what the pipeline does."""
@@ -711,7 +732,9 @@ def _make_template(
ds_sfc_an_regridded = _process_quarter_degree_data_sfc_an(
ds_quarter_degree_sfc.isel(time=0), ds_invariant_regridded, output_grid
)
- ds_native_regridded = _process_native_data(ds_native.isel(time=0), output_grid)
+ ds_native_regridded = _process_native_data(
+ ds_native.isel(time=0), output_grid, output_layer_indices
+ )
ds_google_latlon_regridded = _process_pressure_level_data(
ds_google_latlon.isel(time=0), output_grid
)
@@ -743,6 +766,7 @@ def _make_template(
# land fraction and temporally variable sea ice fraction
# this will get written eagerly since it is not chunked
template = template.update(inv_fields)
+ template = template.update(ds_co2)
return template, inv_fields
@@ -804,9 +828,6 @@ def main():
parser = _get_parser()
args, pipeline_args = parser.parse_known_args()
- # set globals that depend on output layer indices
- set_nlayer_globals(args.output_layer_indices)
-
# desired start/end of output dataset, inclusive
start_time = datetime.datetime.strptime(args.start_time, "%Y-%m-%dT%H:%M:%S")
end_time = datetime.datetime.strptime(args.end_time, "%Y-%m-%dT%H:%M:%S")
@@ -855,8 +876,9 @@ def main():
sel_indices
)
ds_google_latlon = open_google_latlon_dataset(sel_indices)
+ ds_co2 = open_co2_dataset(start_time, end_time)
logging.info("Getting vertical coordinate")
- ds_akbk = _get_vertical_coordinate(ds_native, "t")
+ ds_akbk = _get_vertical_coordinate(ds_native, "t", args.output_layer_indices)
logging.info("Generating template")
template, ds_pt25deg_inv_regridded = _make_template(
@@ -865,10 +887,12 @@ def main():
ds_quarter_degree_sfc,
ds_quarter_degree_inv,
ds_google_latlon,
+ ds_co2,
ds_akbk,
output_chunks,
args.reuse_template,
args.output_grid,
+ args.output_layer_indices,
)
logging.info("Template finished generating. Starting pipeline.")
@@ -914,7 +938,11 @@ def main():
p
| "native_DatasetToChunks"
>> xbeam.DatasetToChunks(ds_native, chunks={"time": 1})
- | beam.MapTuple(process_native_data, output_grid=args.output_grid)
+ | beam.MapTuple(
+ process_native_data,
+ output_grid=args.output_grid,
+ output_layer_indices=args.output_layer_indices,
+ )
| "native_ConsolidateChunks" >> xbeam.ConsolidateChunks(output_chunks)
| "native_to_zarr"
>> xbeam.ChunksToZarr(args.output_path, template, output_chunks)
diff --git a/scripts/manual_backwards_compatibility/.gitignore b/scripts/manual_backwards_compatibility/.gitignore
new file mode 100644
index 0000000..1cf850e
--- /dev/null
+++ b/scripts/manual_backwards_compatibility/.gitignore
@@ -0,0 +1 @@
+test_inference_ace2_era5
\ No newline at end of file
diff --git a/scripts/manual_backwards_compatibility/ace2-era5.sh b/scripts/manual_backwards_compatibility/ace2-era5.sh
new file mode 100755
index 0000000..4768fc9
--- /dev/null
+++ b/scripts/manual_backwards_compatibility/ace2-era5.sh
@@ -0,0 +1,33 @@
+#!/bin/bash
+
+set -e
+
+# This script can be used to ensure that your currently installed version of the fme
+# package can do inference with the published ACE2-ERA5 model.
+
+# download necessary data
+mkdir -p test_inference_ace2_era5
+cd test_inference_ace2_era5
+mkdir -p initial_conditions
+mkdir -p forcing_data
+wget https://huggingface.co/allenai/ACE2-ERA5/resolve/main/ace2_era5_ckpt.tar?download=true -O ace2_era5_ckpt.tar
+wget https://huggingface.co/allenai/ACE2-ERA5/resolve/main/inference_config.yaml?download=true -O inference_config.yaml
+wget https://huggingface.co/allenai/ACE2-ERA5/resolve/main/initial_conditions/ic_2020.nc?download=true -O initial_conditions/ic_2020.nc
+wget https://huggingface.co/allenai/ACE2-ERA5/resolve/main/forcing_data/forcing_2020.nc?download=true -O forcing_data/forcing_2020.nc
+
+# update config to use relative paths and do a short run
+yq e '.n_forward_steps = 50' -i inference_config.yaml
+yq e '.forward_steps_in_memory = 5' -i inference_config.yaml
+yq e '.checkpoint_path = "ace2_era5_ckpt.tar"' -i inference_config.yaml
+yq e '.initial_condition.path = "initial_conditions/ic_2020.nc"' -i inference_config.yaml
+yq e '.forcing_loader.dataset.data_path = "forcing_data/"' -i inference_config.yaml
+
+# run on CPU or CUDA if the latter is available
+yq e '.experiment_dir = "output_cpu"' -i inference_config.yaml
+python -m fme.ace.inference inference_config.yaml
+
+# run on MPS. NOTE: this requires torch==2.5; otherwise there are complaints about
+# some of the features used by the SFNO architecture.
+yq e '.experiment_dir = "output_mps"' -i inference_config.yaml
+export FME_USE_MPS=1
+python -m fme.ace.inference inference_config.yaml
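Running the check above assumes wget, yq (the 'yq e ... -i' invocations used here follow the Go yq v4 syntax), and an installed fme package; a plausible way to run it end to end:

    cd scripts/manual_backwards_compatibility
    ./ace2-era5.sh

The final MPS run is aimed at machines where PyTorch's MPS backend is available (Apple silicon).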
diff --git a/scripts/monthly_data/test_write_monthly_data.py b/scripts/monthly_data/test_write_monthly_data.py
index a27a638..3c2e44c 100644
--- a/scripts/monthly_data/test_write_monthly_data.py
+++ b/scripts/monthly_data/test_write_monthly_data.py
@@ -1,13 +1,15 @@
import pathlib
from typing import List
+import pytest
import xarray as xr
from write_monthly_data import Config, run
-from fme.core.data_loading.config import DataLoaderConfig, XarrayDataConfig
+from fme.ace.data_loading.config import DataLoaderConfig
+from fme.ace.testing import DimSize, DimSizes
+from fme.ace.testing.fv3gfs_data import save_nd_netcdf
+from fme.core.dataset.config import XarrayDataConfig
from fme.core.logging_utils import LoggingConfig
-from fme.core.testing import DimSizes
-from fme.core.testing.fv3gfs_data import save_2d_netcdf
def write_ensemble_dataset(
@@ -18,7 +20,7 @@ def write_ensemble_dataset(
for i in range(n_members):
ensemble_dir = path / f"ic_{i:04d}"
ensemble_dir.mkdir(exist_ok=True)
- save_2d_netcdf(
+ save_nd_netcdf(
ensemble_dir / "data.nc",
dim_sizes,
names,
@@ -26,12 +28,14 @@ def write_ensemble_dataset(
)
-def test_write_monthly_data(tmp_path: pathlib.Path):
+def test_write_monthly_data(very_fast_only: bool, tmp_path: pathlib.Path):
+ if very_fast_only:
+ pytest.skip("Skipping non-fast tests")
all_names = ["a", "b"]
+ horizontal = [DimSize("grid_yt", 8), DimSize("grid_xt", 4)]
dim_sizes = DimSizes(
n_time=4 * 60,
- n_lat=4,
- n_lon=8,
+ horizontal=horizontal,
nz_interface=2,
)
n_members = 3
@@ -45,7 +49,7 @@ def test_write_monthly_data(tmp_path: pathlib.Path):
data_loader=DataLoaderConfig(
dataset=dataset,
batch_size=1,
- num_data_workers=1,
+ num_data_workers=0,
),
logging=LoggingConfig(
log_to_screen=True,
diff --git a/scripts/monthly_data/write_monthly_data.py b/scripts/monthly_data/write_monthly_data.py
index 39057d7..096415a 100644
--- a/scripts/monthly_data/write_monthly_data.py
+++ b/scripts/monthly_data/write_monthly_data.py
@@ -1,9 +1,8 @@
import argparse
import dataclasses
-import datetime
import logging
import os
-from typing import List
+from typing import List, Sequence, Tuple
import dacite
import torch.utils.data
@@ -11,24 +10,42 @@
import yaml
import fme.core.logging_utils as logging_utils
+from fme.ace.data_loading.batch_data import BatchData, default_collate
+from fme.ace.data_loading.config import DataLoaderConfig
from fme.ace.inference.data_writer.monthly import (
MonthlyDataWriter,
months_for_timesteps,
)
-from fme.ace.inference.derived_variables import compute_derived_quantities
-from fme.core.data_loading.config import DataLoaderConfig
-from fme.core.data_loading.data_typing import SigmaCoordinates
-from fme.core.data_loading.getters import get_datasets
-from fme.core.data_loading.requirements import DataRequirements
-from fme.core.data_loading.utils import BatchData
+from fme.ace.stepper import AtmosphericDeriveFn
+from fme.core.dataset.getters import get_datasets
+from fme.core.dataset.requirements import DataRequirements
+from fme.core.dataset.xarray import DatasetProperties
from fme.core.device import using_gpu
from fme.core.distributed import Distributed
from fme.core.logging_utils import LoggingConfig
+from fme.core.typing_ import TensorMapping
+
+
+@dataclasses.dataclass
+class CollateFn:
+ horizontal_dims: List[str]
+
+ def __call__(
+ self, samples: Sequence[Tuple[TensorMapping, xr.DataArray]]
+ ) -> "BatchData":
+ sample_data, sample_time = zip(*samples)
+ batch_data = default_collate(sample_data)
+ batch_time = xr.concat(sample_time, dim="sample")
+ return BatchData(
+ data=batch_data,
+ time=batch_time,
+ horizontal_dims=self.horizontal_dims,
+ )
def get_data_loaders(
config: DataLoaderConfig, requirements: DataRequirements
-) -> List[torch.utils.data.DataLoader]:
+) -> Tuple[List[torch.utils.data.DataLoader], DatasetProperties]:
dist = Distributed.get_instance()
if dist.world_size > 1:
raise RuntimeError(
@@ -36,7 +53,7 @@ def get_data_loaders(
"supported in distributed mode."
)
- datasets = get_datasets(config.dataset, requirements)
+ datasets, properties = get_datasets(config.dataset, requirements)
data_loaders = []
for dataset in datasets:
@@ -48,10 +65,12 @@ def get_data_loaders(
sampler=None,
drop_last=True,
pin_memory=using_gpu(),
- collate_fn=BatchData.from_sample_tuples,
+ collate_fn=CollateFn(
+ horizontal_dims=list(properties.horizontal_coordinates.dims),
+ ),
)
data_loaders.append(dataloader)
- return data_loaders
+ return data_loaders, properties
def get_timesteps(data_loaders: List[torch.utils.data.DataLoader]) -> int:
@@ -66,7 +85,7 @@ class Config:
"""
Configuration for applying the MonthlyDataWriter to a dataset.
- Attributes:
+ Parameters:
experiment_dir: Directory to save results to.
dataset: Configuration for the dataset to load.
num_data_workers: Number of parallel workers to use for data loading.
@@ -88,7 +107,7 @@ def __post_init__(self):
raise ValueError("Batch size must be 1 to write dataset using writer.")
def get_data(self) -> "Data":
- data_loaders = get_data_loaders(
+ data_loaders, properties = get_data_loaders(
config=self.data_loader,
requirements=DataRequirements(
names=self.variable_names,
@@ -98,8 +117,7 @@ def get_data(self) -> "Data":
n_timesteps = get_timesteps(data_loaders=data_loaders)
return Data(
loaders=data_loaders,
- sigma_coordinates=data_loaders[0].dataset.sigma_coordinates,
- timestep=data_loaders[0].dataset.timestep,
+ properties=properties,
n_timesteps=n_timesteps,
)
@@ -107,10 +125,10 @@ def configure_logging(self, log_filename: str):
self.logging.configure_logging(self.experiment_dir, log_filename)
def get_data_writer(self, data: "Data") -> MonthlyDataWriter:
- n_months = months_for_timesteps(data.n_timesteps, data.timestep)
+ n_months = months_for_timesteps(data.n_timesteps, data.properties.timestep)
coords = {
- **data.loaders[0].dataset.horizontal_coordinates.coords,
- **data.loaders[0].dataset.sigma_coordinates.coords,
+ **data.properties.horizontal_coordinates.coords,
+ **data.properties.vertical_coordinate.coords,
}
return MonthlyDataWriter(
path=self.experiment_dir,
@@ -118,7 +136,7 @@ def get_data_writer(self, data: "Data") -> MonthlyDataWriter:
save_names=None, # save all data given
n_samples=self.data_loader.batch_size * len(data.loaders),
n_months=n_months,
- metadata=data.loaders[0].dataset.metadata,
+ variable_metadata=data.properties.variable_metadata,
coords=coords,
)
@@ -126,8 +144,7 @@ def get_data_writer(self, data: "Data") -> MonthlyDataWriter:
@dataclasses.dataclass
class Data:
loaders: List[torch.utils.data.DataLoader]
- sigma_coordinates: SigmaCoordinates
- timestep: datetime.timedelta
+ properties: DatasetProperties
n_timesteps: int
@@ -135,12 +152,15 @@ def merge_loaders(loaders: List[torch.utils.data.DataLoader]):
window_batch_data_list: List[BatchData]
for window_batch_data_list in zip(*loaders):
tensors = [item.data for item in window_batch_data_list]
- times = [item.times for item in window_batch_data_list]
+ time = [item.time for item in window_batch_data_list]
window_batch_data = {
k: torch.concat([d[k] for d in tensors]) for k in tensors[0].keys()
}
- times = xr.concat(times, dim="sample")
- yield BatchData(data=window_batch_data, times=times)
+ time = xr.concat(time, dim="sample")
+ yield BatchData(
+ data=window_batch_data,
+ time=time,
+ )
def run(config: Config):
@@ -150,19 +170,25 @@ def run(config: Config):
data = config.get_data()
writer = config.get_data_writer(data)
+ derive_func = AtmosphericDeriveFn(
+ vertical_coordinate=data.properties.vertical_coordinate,
+ timestep=data.properties.timestep,
+ )
+
n_batches = len(data.loaders[0].dataset) // config.data_loader.batch_size
for i, window_batch_data in enumerate(merge_loaders(data.loaders)):
# no need to trim initial conditions because
# we set n_timesteps to 1 in the DataRequirements
assert list(window_batch_data.data.values())[0].shape[1] == 1
- window_batch_data.data = compute_derived_quantities(
- window_batch_data.data, data.sigma_coordinates, data.timestep
+ window_batch_data = window_batch_data.compute_derived_variables(
+ derive_func=derive_func,
+ forcing_data=window_batch_data,
)
writer.append_batch(
data=window_batch_data.data,
start_timestep=-1, # ignored
- batch_times=window_batch_data.times,
+ batch_time=window_batch_data.time,
)
if i % 10 == 0:
logging.info(f"Writing batch {i+1} of {n_batches}.")
diff --git a/setup.cfg b/setup.cfg
deleted file mode 100644
index 4d358e7..0000000
--- a/setup.cfg
+++ /dev/null
@@ -1,4 +0,0 @@
-[flake8]
-exclude = docs
-ignore = E203,W293,W503,F541,E402
-max-line-length = 88