diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index 710bf8f..ef6c605 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -13,17 +13,41 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 + - name: Set environment variables + run: | + echo "CURRENT_WEEK=$(date +'%Y-%U')" >> $GITHUB_ENV - uses: actions/setup-python@v4 with: python-version: "3.10" - uses: actions/cache@v4 with: path: ${{ env.pythonLocation }} - key: ${{ env.pythonLocation }}-${{ hashFiles('fme/requirements.txt') }}-${{ hashFiles('fme/dev-requirements.txt') }}-${{ hashFiles('fme/constraints.txt') }} + key: ${{ env.CURRENT_WEEK }}-${{ env.pythonLocation }}-${{ hashFiles('fme/requirements.txt') }}-${{ hashFiles('fme/dev-requirements.txt') }}-${{ hashFiles('constraints.txt') }} - name: Install dependencies run: | - python -m pip install uv==0.2.5 - uv pip install --system -c constraints.txt -e fme[dev] + python -m pip install uv + uv pip install --system -c constraints.txt -e ./fme[dev] - name: Run pytest run: | make test + cpu-very-fast: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - name: Set environment variables + run: | + echo "CURRENT_WEEK=$(date +'%Y-%U')" >> $GITHUB_ENV + - uses: actions/setup-python@v4 + with: + python-version: "3.10" + - uses: actions/cache@v4 + with: + path: ${{ env.pythonLocation }} + key: ${{ env.CURRENT_WEEK }}-${{ env.pythonLocation }}-${{ hashFiles('fme/requirements.txt') }}-${{ hashFiles('fme/dev-requirements.txt') }}-${{ hashFiles('constraints.txt') }} + - name: Install dependencies + run: | + python -m pip install uv + uv pip install --system -c constraints.txt -e ./fme[dev] + - name: Run pytest + run: | + make test_very_fast diff --git a/.gitignore b/.gitignore index a155547..38980e2 100644 --- a/.gitignore +++ b/.gitignore @@ -136,3 +136,6 @@ dmypy.json # scratch directory for testing scratch/ + +# Some in progress data pipelines get added here +scripts/data_process/.nfs* diff --git a/.isort.cfg b/.isort.cfg deleted file mode 100644 index b9fb3f3..0000000 --- a/.isort.cfg +++ /dev/null @@ -1,2 +0,0 @@ -[settings] -profile=black diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 8e01376..55bdf31 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,41 +1,24 @@ repos: -- repo: https://github.com/psf/black - rev: 23.3.0 - hooks: - - id: black - additional_dependencies: ["click==8.0.4"] +- repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.8.1 + hooks: + - id: ruff + args: ["--fix", "--config", "fme/pyproject.toml"] + - id: ruff-format - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v2.3.0 + rev: v5.0.0 hooks: - id: check-added-large-files args: [--maxkb=250] - id: trailing-whitespace - - id: flake8 - name: flake8 - language_version: python3 - exclude: "__init__.py" - args: [--config, setup.cfg] - - id: flake8 - name: flake8 __init__.py files - files: "__init__.py" - # ignore unused import error in __init__.py files - args: ["--ignore=F401,E203,W503", --config, setup.cfg] - repo: https://github.com/pre-commit/mirrors-mypy rev: v1.2.0 hooks: - id: mypy additional_dependencies: ["types-PyYaml==5.4.3"] - args: [ - --follow-imports, silent, --ignore-missing-imports - ] + args: ["--ignore-missing-imports", "--check-untyped-defs"] exclude: | (?x)^( .+/conf.py | - .+/setup.py | .+/conftest.py - )$ -- repo: https://github.com/pycqa/isort - rev: 5.11.5 - hooks: - - id: isort - name: isort (python) \ No newline at end of file + )$ \ No newline at end of 
file diff --git a/Makefile b/Makefile index 0ff16ed..adf7a36 100644 --- a/Makefile +++ b/Makefile @@ -4,7 +4,7 @@ ENVIRONMENT_NAME ?= fme DEPLOY_TARGET ?= pypi build_docker_image: - docker build -f docker/Dockerfile -t $(IMAGE):$(VERSION) . + docker build --platform=linux/amd64 -f docker/Dockerfile -t $(IMAGE):$(VERSION) . enter_docker_image: build_docker_image docker run -it --rm $(IMAGE):$(VERSION) bash @@ -12,11 +12,18 @@ enter_docker_image: build_docker_image # recommended to deactivate current conda environment before running this create_environment: conda create -n $(ENVIRONMENT_NAME) python=3.10 pip - conda run --no-capture-output -n $(ENVIRONMENT_NAME) python -m pip install uv==0.2.5 - conda run --no-capture-output -n $(ENVIRONMENT_NAME) uv pip install -c constraints.txt -e fme[dev] + conda run --no-capture-output -n $(ENVIRONMENT_NAME) python -m pip install uv + conda run --no-capture-output -n $(ENVIRONMENT_NAME) uv pip install -c constraints.txt -e ./fme[dev,docs] + conda run --no-capture-output -n $(ENVIRONMENT_NAME) uv pip install -r analysis-deps.txt test: - pytest --durations 20 . + pytest --durations 40 . + +test_fast: + pytest --durations 40 --fast . + +test_very_fast: + pytest --durations 40 --very-fast . # For maintainer use only # requires fme[deploy] to be installed diff --git a/README.md b/README.md index f6f944b..2efacac 100644 --- a/README.md +++ b/README.md @@ -43,4 +43,4 @@ gs://ai2cm-public-requester-pays/2024-11-13-ai2-climate-emulator-v2-amip/data/er The dataset used in the [ACE2-SOM paper](https://arxiv.org/abs/2412.04418) is available at: ``` gs://ai2cm-public-requester-pays/2024-12-05-ai2-climate-emulator-v2-som/SHiELD-SOM-C96 -``` \ No newline at end of file +``` diff --git a/analysis-deps.txt b/analysis-deps.txt new file mode 100644 index 0000000..3edb535 --- /dev/null +++ b/analysis-deps.txt @@ -0,0 +1,13 @@ +# these are some packages which are convenient to have installed for ad-hoc analysis +# but which are not requirements of the "fme" package. We do not relist the fme +# dependencies here. 
+beaker-py +Bottleneck +cartopy>=0.22.0 +dask[distributed] +ipywidgets +nc-time-axis +jupyterlab +pyproj<3.7 +seaborn +bokeh>=3.1.0 \ No newline at end of file diff --git a/conftest.py b/conftest.py new file mode 100644 index 0000000..265e2a4 --- /dev/null +++ b/conftest.py @@ -0,0 +1,68 @@ +import signal + +import pytest + + +def pytest_addoption(parser): + parser.addoption( + "--fast", + action="store_true", + default=False, + help="Skip slow tests", + ) + parser.addoption( + "--very-fast", + action="store_true", + default=False, + help="Run only very fast tests (< 5 seconds)", + ) + + +@pytest.fixture +def skip_slow(request, very_fast_only): + return very_fast_only or request.config.getoption("--fast") + + +@pytest.fixture +def very_fast_only(request): + return request.config.getoption("--very-fast") + + +class TimeoutException(Exception): + pass + + +def timeout_handler(signum, frame): + raise TimeoutException("Test took too long") + + +@pytest.fixture +def pdb_enabled(request): + return request.config.getoption("--pdb") + + +@pytest.fixture(autouse=True, scope="function") +def enforce_timeout(skip_slow, very_fast_only, pdb_enabled): + if pdb_enabled: + yield # Do not enforce timeout if we are debugging + return + if very_fast_only: + timeout_seconds = 3 + elif skip_slow: + timeout_seconds = 30 + else: + timeout_seconds = 60 + signal.signal(signal.SIGALRM, timeout_handler) + signal.alarm(timeout_seconds) # Set the timeout for the test + try: + yield + finally: + signal.alarm(0) # Disable the alarm after the test completes + + +@pytest.hookimpl(tryfirst=True, hookwrapper=True) +def pytest_runtest_call(item): + try: + yield + except TimeoutException: + pytest.fail("Test failed due to timeout") diff --git a/fme/conftest.py b/fme/conftest.py deleted file mode 100644 index 475fb2d..0000000 --- a/fme/conftest.py +++ /dev/null @@ -1,12 +0,0 @@ -import pytest - - -def pytest_addoption(parser): - parser.addoption( - "--fast", action="store_true", default=False, help="Run only fast tests" - ) - - -@pytest.fixture -def skip_slow(request): - return request.config.getoption("--fast") diff --git a/fme/docs/_static/Ai2_icon_pink_RGB.png b/fme/docs/_static/Ai2_icon_pink_RGB.png new file mode 100644 index 0000000..e418a9d Binary files /dev/null and b/fme/docs/_static/Ai2_icon_pink_RGB.png differ diff --git a/fme/docs/_static/Ai2_icon_pink_padding_RGB.png b/fme/docs/_static/Ai2_icon_pink_padding_RGB.png new file mode 100644 index 0000000..d84ccef Binary files /dev/null and b/fme/docs/_static/Ai2_icon_pink_padding_RGB.png differ diff --git a/fme/docs/_static/custom.css b/fme/docs/_static/custom.css new file mode 100644 index 0000000..14354dc --- /dev/null +++ b/fme/docs/_static/custom.css @@ -0,0 +1,21 @@ +body[data-theme="dark"] { + --code-block-background: #202020; +} + +body[data-theme="light"] { + --code-block-background: #f8f9fb; +} + +body[data-theme="auto"] { + --code-block-background: #f8f9fb; +} + +@media (prefers-color-scheme: dark) { + body[data-theme="auto"] { + --code-block-background: #202020; + } +} + +div.highlight pre { + background: var(--code-block-background); +} \ No newline at end of file diff --git a/fme/docs/builder.rst b/fme/docs/builder.rst index 5a23d31..e32f7c0 100644 --- a/fme/docs/builder.rst +++ b/fme/docs/builder.rst @@ -51,7 +51,7 @@ Let's define a training configuration ``TrainConfig`` containing an ``OptimizerC """ Configuration for an optimizer. - Attributes: + Parameters: optimizer_type: The type of optimizer to use. lr: The learning rate. 
kwargs: Additional keyword arguments to pass to the optimizer. diff --git a/fme/docs/conf.py b/fme/docs/conf.py index eeae207..20f9794 100755 --- a/fme/docs/conf.py +++ b/fme/docs/conf.py @@ -24,7 +24,7 @@ sys.path.insert(0, os.path.abspath("..")) import fme # noqa -import fme.ace +import fme.core.registry # -- General configuration --------------------------------------------- @@ -33,7 +33,7 @@ # needs_sphinx = '1.0' # Fetch the dynamic data -module_types = fme.ace.get_available_module_types() +module_types = fme.core.registry.ModuleSelector.get_available_types() # Create a dynamic rst snippet that can be included in your documentation rst_snippet = f".. code-block:: text\n\n {module_types}" @@ -50,7 +50,15 @@ "sphinx.ext.autodoc", "sphinx.ext.viewcode", "sphinx.ext.doctest", + "sphinx.ext.intersphinx", + "sphinx_autodoc_typehints", ] +autodoc_typehints = "description" + +# Intersphinx configuration to link to other projects' documentation +intersphinx_mapping = { + "python": ("https://docs.python.org/3", None), +} # Add any paths that contain templates here, relative to this directory. templates_path = ["_templates"] @@ -93,6 +101,9 @@ # The name of the Pygments (syntax highlighting) style to use. pygments_style = "sphinx" +# Include default values when documenting parameter types. +typehints_defaults = "comma" + # If true, `todo` and `todoList` produce output, else they produce nothing. todo_include_todos = False @@ -102,18 +113,37 @@ # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. # -html_theme = "sphinx_rtd_theme" +html_theme = "furo" +html_title = f"fme v{fme.__version__}" # Theme options are theme-specific and customize the look and feel of a # theme further. For a list of options available for each theme, see the # documentation. # -# html_theme_options = {} +html_theme_options = { + "light_logo": "Ai2_icon_pink_padding_RGB.png", + "dark_logo": "Ai2_icon_pink_padding_RGB.png", + "footer_icons": [ + { + "name": "GitHub", + "url": "https://github.com/ai2cm/ace", + "html": """ + + + + """, # noqa: E501 + "class": "", + }, + ], +} # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". 
-html_static_path = [] +html_static_path = ["_static"] +html_css_files = ["custom.css"] + +html_favicon = "_static/Ai2_icon_pink_RGB.png" # -- Options for HTMLHelp output --------------------------------------- @@ -176,3 +206,9 @@ "Miscellaneous", ), ] + + +# -- Options for doctest ----------------------------------------------- +doctest_global_setup = """ +import fme +""" diff --git a/fme/docs/configs/explicit-indices.yaml b/fme/docs/configs/explicit-indices.yaml new file mode 100644 index 0000000..3080a5e --- /dev/null +++ b/fme/docs/configs/explicit-indices.yaml @@ -0,0 +1,3 @@ +path: initial_conditions.nc +start_indices: + list: [0, 3, 7] diff --git a/fme/docs/configs/inference-ic-indices.yaml b/fme/docs/configs/inference-ic-indices.yaml new file mode 100644 index 0000000..8c30681 --- /dev/null +++ b/fme/docs/configs/inference-ic-indices.yaml @@ -0,0 +1,5 @@ +path: initial_conditions.nc +start_indices: + n_initial_conditions: 3 + first: 1 + interval: 2 diff --git a/fme/docs/configs/timestamp-list.yaml b/fme/docs/configs/timestamp-list.yaml new file mode 100644 index 0000000..729f315 --- /dev/null +++ b/fme/docs/configs/timestamp-list.yaml @@ -0,0 +1,5 @@ +path: initial_conditions.nc +start_indices: + times: + - "2021-01-01T00:00:00" + - "2021-02-01T00:00:00" diff --git a/fme/docs/index.rst b/fme/docs/index.rst index 6657d3b..df4c0af 100644 --- a/fme/docs/index.rst +++ b/fme/docs/index.rst @@ -1,6 +1,9 @@ -Welcome to ACE's documentation! +fme: Full Model Emulation ====================================== +**fme** ("full model emulation") is a python package for training and running +climate model emulators, such as the Ai2 Climate Emulator. + .. toctree:: :maxdepth: 1 :caption: Contents: diff --git a/fme/docs/inference-config.yaml b/fme/docs/inference-config.yaml index 277ec47..e616530 100644 --- a/fme/docs/inference-config.yaml +++ b/fme/docs/inference-config.yaml @@ -1,19 +1,25 @@ experiment_dir: inference_output -n_forward_steps: 400 # 100 days +n_forward_steps: 400 # 100 days forward_steps_in_memory: 50 -checkpoint_path: ckpt.tar +checkpoint_path: ace_ckpt.tar logging: log_to_screen: true log_to_wandb: false log_to_file: true project: ace - entity: your_wandb_entity initial_condition: - path: initial_condition/data.nc + path: climSST/ic_2021.zarr + start_indices: + n_initial_conditions: 2 + first: 0 + interval: 3 + engine: zarr forcing_loader: dataset: - data_path: forcing - num_data_workers: 8 + data_path: climSST + file_pattern: forcing_2021.zarr + engine: zarr + n_repeats: 2 # use this to extend the 1-year of forcing data to desired length + num_data_workers: 2 data_writer: save_prediction_files: false - save_monthly_files: false diff --git a/fme/docs/inference_config.rst b/fme/docs/inference_config.rst index dc8612d..5f1a98d 100644 --- a/fme/docs/inference_config.rst +++ b/fme/docs/inference_config.rst @@ -11,25 +11,30 @@ The example assumes you are running in a directory structure like: :: . - ├── ckpt.tar - ├── initial_condition - │ └── data.nc # name must reflect the path in the config - └── forcing - ├── data1.nc # files can have any name, but must sort into time-sequential order - ├── data2.nc # can have any number of netCDF files - └── ... + ├── ace_ckpt.tar + ├── climSST + │ ├── forcing_2021.zarr + │ ├── ic_2021-01-01.zarr + │ └── ic_2021.zarr + └── inference-config.yaml -The ``.nc`` files correspond to data files like ``2021010100.nc`` in the `Zenodo repository`_, while ``ckpt.tar`` corresponds to a file like ``ace_ckpt.tar`` in that repository. 
+that includes a model checkpoint (``ace_ckpt.tar``), forcing data (``forcing_2021.zarr``), and an initial condition (e.g., ``ic_2021-01-01.zarr``). You can find the forcing and initial condition data in the `Zenodo repository`_. The specified initial condition file should contain a time dimension of at least length 1, but can also -contain multiple times. If multiple times are present and `start_indices` is not specified in the -configuration, the inference will run an ensemble using all times in the initial condition file. -Selections from initial conditions can be made using the `start_indices` parameter in the configuration. +contain multiple times. If multiple times are present and ``start_indices`` is not specified in the +:class:`fme.ace.InitialConditionConfig` configuration, the inference will run an ensemble using all times +in the initial condition file. The ``ic_2021.zarr`` file is an example of a file with multiple times, containing +initial conditions for each month of 2021. For examples of selecting specific initial +conditions, see :ref:`initial-condition-examples`. -While netCDF files are specified in the example, Zarr datasets are also compatible. E.g., -specifying `data.zarr` as the `path` setting `engine` to `zarr` in the dataset configuration. +While Zarr files are specified in the example, netCDFs are also compatible. E.g., +specify the parent folder containing the netCDF files as the ``path`` and set ``engine`` to ``netcdf4`` +in the dataset configuration. See :class:`fme.ace.XarrayDataConfig` for an example. -.. _Zenodo repository: https://zenodo.org/doi/10.5281/zenodo.10791086 +Example YAML Configuration +--------------------------- + +.. _Zenodo repository: https://zenodo.org/records/13787710 .. literalinclude:: inference-config.yaml :language: yaml @@ -52,9 +57,9 @@ specifying `data.zarr` as the `path` setting `engine` to `zarr` in the dataset ) # these paths are used in the documentation on this page # if they change then update the docs! - assert config.checkpoint_path == "ckpt.tar" - assert config.initial_condition.path == "initial_condition/data.nc" - assert config.forcing_loader.dataset.data_path == "forcing" + assert config.checkpoint_path == "ace_ckpt.tar" + assert config.initial_condition.path == "climSST/ic_2021.zarr" + assert config.forcing_loader.dataset.data_path == "climSST" print("Loaded successfully") .. testoutput:: @@ -62,6 +67,9 @@ specifying `data.zarr` as the `path` setting `engine` to `zarr` in the dataset Loaded successfully +Configuration structure +----------------------- + We use the :ref:`Builder pattern ` to load this configuration into a multi-level dataclass structure. The configuration is divided into several sub-configurations, each with its own dataclass. The top-level configuration is the :class:`fme.ace.InferenceConfig` class. @@ -99,3 +107,89 @@ The sub-configurations are: .. autoclass:: fme.ace.OceanConfig :show-inheritance: :noindex: + + + .. _initial-condition-examples: + +:class:`fme.ace.InitialConditionConfig` Examples +------------------------------------------------- + +The ``start_indices`` attribute can be used to specify which initial conditions +to use when multiple are present in the dataset (instead of using all available). +The following examples show example selections using the YAML builder pattern for +an ``InitialConditionConfig``. + + +:class:`fme.ace.InferenceInitialConditionIndices` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +Select a number of regularly spaced initial conditions. + +.. 
literalinclude:: configs/inference-ic-indices.yaml + :language: yaml + +.. testcode:: + :hide: + + from fme.ace import InitialConditionConfig + import dacite + import yaml + + with open('configs/inference-ic-indices.yaml', 'r') as f: + config_dict = yaml.safe_load(f) + + config = dacite.from_dict( + InitialConditionConfig, + data=config_dict, + config=dacite.Config(strict=True) + ) + assert config.start_indices.n_initial_conditions == 3 + +:class:`fme.ace.TimestampList` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +Selecting two timestamps from the initial conditions. + +.. literalinclude:: configs/timestamp-list.yaml + :language: yaml + +.. testcode:: + :hide: + + from fme.ace import InitialConditionConfig + import dacite + import yaml + + with open('configs/timestamp-list.yaml', 'r') as f: + config_dict = yaml.safe_load(f) + + config = dacite.from_dict( + InitialConditionConfig, + data=config_dict, + config=dacite.Config(strict=True) + ) + assert config.start_indices.times[0] == '2021-01-01T00:00:00' + +:class:`fme.ace.ExplicitIndices` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +Selecting specific indices from the initial conditions. + +.. literalinclude:: configs/explicit-indices.yaml + :language: yaml + +.. testcode:: + :hide: + + from fme.ace import InitialConditionConfig + import dacite + import yaml + + with open('configs/explicit-indices.yaml', 'r') as f: + config_dict = yaml.safe_load(f) + + config = dacite.from_dict( + InitialConditionConfig, + data=config_dict, + config=dacite.Config(strict=True) + ) + assert config.start_indices.list[1] == 3 + + diff --git a/fme/docs/installation.rst b/fme/docs/installation.rst index eb5b27e..335f58e 100644 --- a/fme/docs/installation.rst +++ b/fme/docs/installation.rst @@ -34,7 +34,7 @@ A make target is available to build a conda environment: make create_environment -This will create an environment named ``fme``, and should use the same package versions we have used in development. If you would like a different name, set the ENVIRONMENT_NAME variable: +This will create an environment named ``fme``. If you would like a different name, set the ENVIRONMENT_NAME variable: .. code-block:: shell diff --git a/fme/docs/modules.rst b/fme/docs/modules.rst index a3dd14e..a16b404 100644 --- a/fme/docs/modules.rst +++ b/fme/docs/modules.rst @@ -15,7 +15,7 @@ The following module types are available: .. include:: available_modules.rst -.. autofunction:: fme.ace.get_available_module_types +.. autofunction:: fme.core.registry.ModuleSelector.get_available_types The following module builders are available: @@ -30,3 +30,9 @@ The following module builders are available: :undoc-members: :show-inheritance: :noindex: + +.. autoclass:: fme.ace.HEALPixRecUNetBuilder + :members: + :undoc-members: + :show-inheritance: + :noindex: \ No newline at end of file diff --git a/fme/docs/quickstart.rst b/fme/docs/quickstart.rst index 1488aca..64f1ba7 100644 --- a/fme/docs/quickstart.rst +++ b/fme/docs/quickstart.rst @@ -20,7 +20,7 @@ For the optional Weights and Biases (wandb) integration, you will need to set th export WANDB_API_KEY=wandb-api-key -where `wandb-api-key` is created and retrieved from the "API Keys" section of the `Wandb`_ settings page. +where ``wandb-api-key`` is created and retrieved from the "API Keys" section of the `Wandb`_ settings page. .. _Wandb: https://wandb.ai/settings @@ -58,7 +58,7 @@ For example, the 10-year validation data (approx. 
190GB) can be downloaded with: gsutil -m -u YOUR_GCP_PROJECT cp -r gs://ai2cm-public-requester-pays/2023-11-29-ai2-climate-emulator-v1/data/repeating-climSST-1deg-netCDFs/validation . -It is possible to download a portion of the dataset only, but it is necessary to have enough data to span the desired prediction period. The checkpoint is also available on GCS at `gs://ai2cm-public-requester-pays/2023-11-29-ai2-climate-emulator-v1/checkpoints/ace_ckpt.tar`. +It is possible to download a portion of the dataset only, but it is necessary to have enough data to span the desired prediction period. The checkpoint is also available on GCS at ``gs://ai2cm-public-requester-pays/2023-11-29-ai2-climate-emulator-v1/checkpoints/ace_ckpt.tar``. .. _Zenodo repository: https://zenodo.org/doi/10.5281/zenodo.10791086 .. _requester pays: https://cloud.google.com/storage/docs/requester-pays @@ -83,7 +83,8 @@ If you run into configuration issues, you can validate your configuration with While inference can be performed without a GPU, it may be very slow. If running on a Mac, set the environmental variable ``export FME_USE_MPS=1`` to enable using the `Metal Performance Shaders`_ framework for GPU acceleration. Note this backend is - not fully featured and it may not work with all inference features or for training. + not fully featured and it may not work with all inference features or for training. It is recommended to use the latest version + of torch if using MPS. .. _Metal Performance Shaders: https://developer.apple.com/metal/pytorch/ @@ -171,7 +172,7 @@ Then in the ``fme`` conda environment, run evaluation with: torchrun --nproc_per_node RANK_COUNT -m fme.ace.train config-train.yaml -where RANK_COUNT is how many processors you want to run on. +where ``RANK_COUNT`` is how many processors you want to run on. This will typically be the number of GPUs you have available. If running on a single GPU, you can omit the `torchrun` command and use ``python -m`` instead. 
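For reference, the two launch modes for ``fme.ace.train`` described above read as follows; the GPU count of 4 here is only a placeholder for your ``RANK_COUNT``, and ``config-train.yaml`` is the config file from the quickstart:

.. code-block:: shell

   # multi-GPU: torchrun launches one process per GPU
   torchrun --nproc_per_node 4 -m fme.ace.train config-train.yaml

   # single GPU: torchrun can be omitted
   python -m fme.ace.train config-train.yaml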
diff --git a/fme/docs/requirements.txt b/fme/docs/requirements.txt index 6deb546..41b311d 100644 --- a/fme/docs/requirements.txt +++ b/fme/docs/requirements.txt @@ -1,2 +1,3 @@ +furo==2024.04.27 sphinx==7.0.0 -sphinx-rtd-theme==2.0.0 +sphinx_autodoc_typehints \ No newline at end of file diff --git a/fme/fme/ace/__init__.py b/fme/fme/ace/__init__.py index 8d30ba0..27be7f1 100644 --- a/fme/fme/ace/__init__.py +++ b/fme/fme/ace/__init__.py @@ -1,5 +1,16 @@ import sys +from fme.ace.data_loading.inference import ( + ExplicitIndices, + InferenceInitialConditionIndices, + TimestampList, +) +from fme.ace.data_loading.perturbation import ( + ConstantConfig, + GreensFunctionConfig, + PerturbationSelector, + SSTPerturbation, +) from fme.ace.inference.data_writer.time_coarsen import TimeCoarsenConfig from fme.ace.inference.evaluator import ( DataWriterConfig, @@ -16,22 +27,31 @@ InitialConditionConfig, run_inference_from_config, ) -from fme.ace.registry.sfno import SFNO_V0_1_0, SphericalFourierNeuralOperatorBuilder -from fme.core.corrector import CorrectorConfig -from fme.core.data_loading.config import TimeSlice, XarrayDataConfig -from fme.core.data_loading.inference import ( - ExplicitIndices, - InferenceInitialConditionIndices, - TimestampList, +from fme.ace.models.healpix.healpix_activations import ( + CappedGELUConfig, + DownsamplingBlockConfig, ) +from fme.ace.models.healpix.healpix_blocks import ConvBlockConfig, RecurrentBlockConfig +from fme.ace.registry.hpx import ( + HEALPixRecUNetBuilder, + UNetDecoderConfig, + UNetEncoderConfig, +) +from fme.ace.registry.sfno import SFNO_V0_1_0, SphericalFourierNeuralOperatorBuilder +from fme.core.corrector.corrector import CorrectorConfig +from fme.core.corrector.ocean import OceanCorrectorConfig +from fme.core.dataset.config import TimeSlice, XarrayDataConfig +from fme.core.gridded_ops import GriddedOperations from fme.core.loss import WeightedMappingLossConfig -from fme.core.normalizer import FromStateNormalizer, NormalizationConfig +from fme.core.normalizer import NormalizationConfig from fme.core.ocean import SlabOceanConfig from fme.core.optimization import SchedulerConfig from fme.core.parameter_init import FrozenParameterConfig, ParameterInitializationConfig -from fme.core.registry import ModuleSelector, get_available_module_types, register +from fme.core.registry.corrector import CorrectorSelector +from fme.core.registry.module import ModuleSelector +from fme.core.typing_ import Slice -from .train.train import run_train_from_config +from .train.train import run_train from .train.train_config import ( CopyWeightsConfig, DataLoaderConfig, @@ -41,7 +61,6 @@ LoggingConfig, OptimizationConfig, SingleModuleStepperConfig, - Slice, TrainConfig, ) diff --git a/fme/fme/core/aggregator/__init__.py b/fme/fme/ace/aggregator/__init__.py similarity index 100% rename from fme/fme/core/aggregator/__init__.py rename to fme/fme/ace/aggregator/__init__.py diff --git a/fme/fme/core/aggregator/inference/__init__.py b/fme/fme/ace/aggregator/inference/__init__.py similarity index 100% rename from fme/fme/core/aggregator/inference/__init__.py rename to fme/fme/ace/aggregator/inference/__init__.py diff --git a/fme/fme/core/aggregator/inference/annual.py b/fme/fme/ace/aggregator/inference/annual.py similarity index 89% rename from fme/fme/core/aggregator/inference/annual.py rename to fme/fme/ace/aggregator/inference/annual.py index d3c736c..7e51dfb 100644 --- a/fme/fme/core/aggregator/inference/annual.py +++ b/fme/fme/ace/aggregator/inference/annual.py @@ -1,30 +1,30 @@ 
import dataclasses import datetime -from typing import Any, Dict, List, Mapping, Optional, Tuple +from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple import numpy as np import torch import xarray as xr from matplotlib.figure import Figure -from fme.core.data_loading.data_typing import VariableMetadata +from fme.core.dataset.data_typing import VariableMetadata from fme.core.device import get_device from fme.core.distributed import Distributed -from fme.core.metrics import weighted_mean +from fme.core.gridded_ops import GriddedOperations from fme.core.typing_ import TensorMapping class GlobalMeanAnnualAggregator: def __init__( self, - area_weights: torch.Tensor, + ops: GriddedOperations, timestep: datetime.timedelta, - metadata: Optional[Mapping[str, VariableMetadata]] = None, + variable_metadata: Optional[Mapping[str, VariableMetadata]] = None, monthly_reference_data: Optional[xr.Dataset] = None, ): - self.area_weights = area_weights + self._area_weighted_mean = ops.area_weighted_mean self.timestep = timestep - self.metadata = metadata + self.variable_metadata = variable_metadata self._target_datasets: Optional[List[xr.Dataset]] = None self._gen_datasets: Optional[List[xr.Dataset]] = None self._monthly_reference_data = monthly_reference_data @@ -37,7 +37,7 @@ def _get_reference(self, name: str) -> Optional["VariableReferenceData"]: if name not in self._monthly_reference_data: return None self._variable_reference_data[name] = process_monthly_reference( - self._monthly_reference_data, self.area_weights, name + self._monthly_reference_data, self._area_weighted_mean, name ) return self._variable_reference_data[name] @@ -51,12 +51,10 @@ def record_batch( """Record a batch of data for computing time variability statistics.""" target_data_area_mean, gen_data_area_mean = {}, {} for name in gen_data.keys(): - target_data_area_mean[name] = weighted_mean( - target_data[name], self.area_weights, dim=(-1, -2) - ).cpu() - gen_data_area_mean[name] = weighted_mean( - gen_data[name], self.area_weights, dim=(-1, -2) + target_data_area_mean[name] = self._area_weighted_mean( + target_data[name] ).cpu() + gen_data_area_mean[name] = self._area_weighted_mean(gen_data[name]).cpu() target_ds = to_dataset(target_data_area_mean, time) gen_ds = to_dataset(gen_data_area_mean, time) @@ -190,6 +188,19 @@ def get_logs(self, label: str) -> Dict[str, Any]: logs.update({f"{label}{name}": metrics[name] for name in metrics.keys()}) return logs + def get_dataset(self) -> xr.Dataset: + gathered = self._get_gathered_means() + if gathered is None: + return xr.Dataset() + target, gen = gathered + return xr.concat( + [ + target.expand_dims({"source": ["target"]}), + gen.expand_dims({"source": ["prediction"]}), + ], + dim="source", + ) + @dataclasses.dataclass class VariableReferenceData: @@ -199,14 +210,12 @@ class VariableReferenceData: def process_monthly_reference( - monthly_reference_data: xr.Dataset, area_weights: torch.Tensor, name: str + monthly_reference_data: xr.Dataset, + area_weighted_mean: Callable[[torch.Tensor], torch.Tensor], + name: str, ) -> VariableReferenceData: ref_global_mean = xr.DataArray( - weighted_mean( - torch.as_tensor(monthly_reference_data[name].values), - weights=area_weights.cpu(), - dim=(-1, -2), - ), + area_weighted_mean(torch.as_tensor(monthly_reference_data[name].values)), dims=monthly_reference_data[name].dims[:-2], coords={"time": monthly_reference_data[name].coords["time"]}, ) diff --git a/fme/fme/core/aggregator/inference/enso/__init__.py 
b/fme/fme/ace/aggregator/inference/enso/__init__.py similarity index 100% rename from fme/fme/core/aggregator/inference/enso/__init__.py rename to fme/fme/ace/aggregator/inference/enso/__init__.py diff --git a/fme/fme/core/aggregator/inference/enso/enso.py b/fme/fme/ace/aggregator/inference/enso/enso.py similarity index 73% rename from fme/fme/core/aggregator/inference/enso/enso.py rename to fme/fme/ace/aggregator/inference/enso/enso.py index cf14f04..440841c 100644 --- a/fme/fme/core/aggregator/inference/enso/enso.py +++ b/fme/fme/ace/aggregator/inference/enso/enso.py @@ -3,14 +3,15 @@ import cftime import matplotlib.pyplot as plt +import numpy as np import torch import xarray as xr -from fme.core.aggregator.plotting import get_cmap_limits, plot_imshow, plot_paneled_data -from fme.core.data_loading.data_typing import VariableMetadata +from fme.ace.aggregator.plotting import get_cmap_limits, plot_imshow, plot_paneled_data +from fme.core.dataset.data_typing import VariableMetadata from fme.core.device import get_device from fme.core.distributed import Distributed -from fme.core.metrics import root_mean_squared_error +from fme.core.gridded_ops import GriddedOperations from fme.core.typing_ import TensorDict, TensorMapping from fme.core.wandb import WandB @@ -26,6 +27,7 @@ def index_data_array( Args: index_data: List of (time, index) tuples. + calendar: Calendar for the time coordinate. Returns: ENSO index data as an xarray DataArray. @@ -73,31 +75,31 @@ class EnsoCoefficientEvaluatorAggregator: data are time-aggregated. Args: - initial_times: Initial times for each sample. + initial_time: Initial time for each sample. n_forward_timesteps: Number of timesteps for each sample. timestep: Timestep duration. - area_weights: Area weights for spatial averaging. - metadata: Metadata for the variables in the data. + gridded_operations: GriddedOperations instance for area-weighted RMSE. + variable_metadata: Metadata for the variables in the data. 
""" def __init__( self, - initial_times: xr.DataArray, + initial_time: xr.DataArray, n_forward_timesteps: int, timestep: datetime.timedelta, - area_weights: torch.Tensor, - metadata: Optional[Mapping[str, VariableMetadata]] = None, + gridded_operations: GriddedOperations, + variable_metadata: Optional[Mapping[str, VariableMetadata]] = None, ): - self._sample_index_series: List[ - Optional[xr.DataArray] - ] = get_sample_index_series( - self.enso_index, initial_times, n_forward_timesteps, timestep + self._sample_index_series: List[Optional[xr.DataArray]] = ( + get_sample_index_series( + self.enso_index, initial_time, n_forward_timesteps, timestep + ) ) - self._area_weights: torch.Tensor = area_weights - if metadata is not None: - self._metadata: Mapping[str, VariableMetadata] = metadata + self._ops = gridded_operations + if variable_metadata is not None: + self._variable_metadata: Mapping[str, VariableMetadata] = variable_metadata else: - self._metadata = {} + self._variable_metadata = {} n_samples = len(self._sample_index_series) self._target_covariances: List[TensorDict] = [{} for _ in range(n_samples)] self._gen_covariances: List[TensorDict] = [{} for _ in range(n_samples)] @@ -129,38 +131,36 @@ def record_batch( ), "number of index series must match number of samples" for i_sample, sample_index_series in enumerate(self._sample_index_series): if sample_index_series is not None: - sample_index_series_reindexed = sample_index_series.reindex( - time=time.isel(sample=i_sample), method="nearest" - ).values - sample_index_series_reindexed = torch.tensor( - sample_index_series_reindexed, + sample_index_series_window = sample_index_series.sel( + time=time.isel(sample=i_sample) + ) + sample_index_series_window = torch.tensor( + sample_index_series_window.values, device=get_device(), dtype=torch.float32, ) - self._index_variance[i_sample] += ( - sample_index_series_reindexed**2 - ).sum() + self._index_variance[i_sample] += (sample_index_series_window**2).sum() for name, data in target_data.items(): if name not in self._target_covariances[i_sample]: - self._target_covariances[i_sample][ - name - ] = data_index_covariance( - data[i_sample, :], sample_index_series_reindexed + self._target_covariances[i_sample][name] = ( + data_index_covariance( + data[i_sample, :], sample_index_series_window + ) ) else: - self._target_covariances[i_sample][ - name - ] += data_index_covariance( - data[i_sample, :], sample_index_series_reindexed + self._target_covariances[i_sample][name] += ( + data_index_covariance( + data[i_sample, :], sample_index_series_window + ) ) for name, data in gen_data.items(): if name not in self._gen_covariances[i_sample]: self._gen_covariances[i_sample][name] = data_index_covariance( - data[i_sample, :], sample_index_series_reindexed + data[i_sample, :], sample_index_series_window ) else: self._gen_covariances[i_sample][name] += data_index_covariance( - data[i_sample, :], sample_index_series_reindexed + data[i_sample, :], sample_index_series_window ) def _compute_coefficients( @@ -247,9 +247,9 @@ def get_logs(self, label: str) -> Dict[str, Any]: wandb = WandB.get_instance() images, metrics = {}, {} for name in gen_coefficients.keys(): - if name in self._metadata: - caption_name = self._metadata[name].long_name - caption_units = self._metadata[name].units + if name in self._variable_metadata: + caption_name = self._variable_metadata[name].long_name + caption_units = self._variable_metadata[name].units else: caption_name = name caption_units = "unknown units" @@ -267,10 +267,9 @@ def 
get_logs(self, label: str) -> Dict[str, Any]: ) images.update({f"coefficient_maps/{name}": coefficient_map}) rmse = float( - root_mean_squared_error( + self._ops.area_weighted_rmse( predicted=gen_coefficients[name], truth=target_coefficients[name], - weights=self._area_weights, ) .cpu() .numpy() @@ -299,10 +298,45 @@ def get_logs(self, label: str) -> Dict[str, Any]: logs.update({f"{label}{name}": metrics[name] for name in metrics.keys()}) return logs + def get_dataset(self) -> xr.Dataset: + """Get the coefficients as an xarray Dataset.""" + target_coefficients, gen_coefficients = self._get_coefficients() + if target_coefficients is None or gen_coefficients is None: + return xr.Dataset() + target_coefficients_ds = xr.Dataset( + { + name: ( + ["lat", "lon"], + target_coefficients[name].cpu().numpy(), + self._get_var_attrs(name), + ) + for name in target_coefficients.keys() + } + ).expand_dims({"source": ["target"]}) + gen_coefficients_ds = xr.Dataset( + { + name: (["lat", "lon"], gen_coefficients[name].cpu().numpy()) + for name in gen_coefficients.keys() + } + ).expand_dims({"source": ["prediction"]}) + return xr.concat([target_coefficients_ds, gen_coefficients_ds], dim="source") + + def _get_var_attrs(self, name: str) -> Dict[str, str]: + if name in self._variable_metadata: + attrs_name = self._variable_metadata[name].long_name + attrs_units = self._variable_metadata[name].units + else: + attrs_name = name + attrs_units = "unknown units" + return { + "long_name": f"{attrs_name} regression coefficient with Nino3.4 index", + "units": f"{attrs_units} / K", + } + def get_sample_index_series( index_data: xr.DataArray, - initial_times: xr.DataArray, + initial_time: xr.DataArray, n_forward_timesteps: int, timestep: datetime.timedelta, overlap_threshold: float = OVERLAP_THRESHOLD, @@ -312,7 +346,7 @@ def get_sample_index_series( Args: index_data: ENSO index data with a time coordinate. - initial_times: Initial times for each sample. + initial_time: Initial time for each sample. n_forward_timesteps: Number of forward timesteps for each sample. timestep: Timestep duration. overlap_threshold: Required overlap of reference index with inference period. @@ -321,38 +355,54 @@ def get_sample_index_series( List of zero-mean index series for each sample, or None if the sample does not overlap sufficiently with the reference index. 
""" - data_calendar = initial_times.dt.calendar + data_calendar = initial_time.dt.calendar index_calendar = index_data.time.dt.calendar if data_calendar != index_calendar: index_data = index_data.convert_calendar( calendar=data_calendar, dim="time", use_cftime=True ) sample_index_series: List[Optional[xr.DataArray]] = [] - for initial_time in initial_times: + for initial_time_sample in initial_time: duration = n_forward_timesteps * timestep - end_time = initial_time + duration + end_time = initial_time_sample + duration # select index data that overlaps with the inference period, plus a # half-timestep buffer since we will later reindex with nearest neighbor index_timestep_seconds = ( index_data.time[1].item() - index_data.time[0].item() ).total_seconds() half_index_timestep = datetime.timedelta(seconds=index_timestep_seconds / 2) - index_series = index_data.sel( + sample_index_data_selection = index_data.sel( time=slice( - initial_time - half_index_timestep, + initial_time_sample - half_index_timestep, end_time + half_index_timestep, ) ) - if index_series.sizes["time"] == 0: + if sample_index_data_selection.sizes["time"] == 0: + # no overlap sample_index_series.append(None) else: - index_series_duration = ( - index_series.time[-1].item() - index_series.time[0].item() + sample_time = xr.cftime_range( + start=initial_time_sample.item(), + end=end_time.item(), + freq=f"{int(timestep.total_seconds())}s", + calendar=data_calendar, ) - if index_series_duration > overlap_threshold * duration: - index_series = index_series - index_series.mean() - sample_index_series.append(index_series) + valid_sample_time = sample_time.where( + np.logical_and( + sample_time >= sample_index_data_selection.time[0], + sample_time <= sample_index_data_selection.time[-1], + ), + ).dropna() + if len(valid_sample_time) > len(sample_time) * overlap_threshold: + reindexed_series = sample_index_data_selection.reindex( + time=sample_time, method="nearest" + ) + reindexed_series_zero_mean = reindexed_series - reindexed_series.mean( + "time" + ) + sample_index_series.append(reindexed_series_zero_mean) else: + # insufficient overlap sample_index_series.append(None) return sample_index_series diff --git a/fme/fme/core/aggregator/inference/enso/index.py b/fme/fme/ace/aggregator/inference/enso/index.py similarity index 100% rename from fme/fme/core/aggregator/inference/enso/index.py rename to fme/fme/ace/aggregator/inference/enso/index.py diff --git a/fme/fme/core/aggregator/inference/enso/test_enso.py b/fme/fme/ace/aggregator/inference/enso/test_enso.py similarity index 73% rename from fme/fme/core/aggregator/inference/enso/test_enso.py rename to fme/fme/ace/aggregator/inference/enso/test_enso.py index 1ee1258..8bef935 100644 --- a/fme/fme/core/aggregator/inference/enso/test_enso.py +++ b/fme/fme/ace/aggregator/inference/enso/test_enso.py @@ -8,6 +8,7 @@ import xarray as xr from fme.core.device import get_device +from fme.core.gridded_ops import LatLonOperations from .enso import OVERLAP_THRESHOLD, EnsoCoefficientEvaluatorAggregator @@ -39,14 +40,14 @@ def _get_data( n_lon: int, calendar: str = "julian", ): - times = xr.cftime_range( + time = xr.cftime_range( start="2000-01-01", periods=n_times, freq="6h", calendar=calendar, ) enso_index = xr.DataArray( - _data_generator(scale, n_times), dims=["time"], coords={"time": times} + _data_generator(scale, n_times), dims=["time"], coords={"time": time} ) target_data = { # make target data perfectly correlated "a": torch.tile( @@ -60,17 +61,17 @@ def _get_data( [n_samples, 1, n_lat, 
n_lon], ).to(device=get_device()), } - sample_times = xr.concat( + sample_time = xr.concat( [ xr.DataArray( - times.values, + time.values, dims=["time"], ) for _ in range(n_samples) ], dim="sample", ) - return enso_index, sample_times, target_data, gen_data + return enso_index, sample_time, target_data, gen_data @pytest.mark.parametrize("scaling", [0.5, 1.0, 2.0]) @@ -81,10 +82,11 @@ def test_enso_coefficient_aggregator_values(scaling): Check that: - The aggregator maintains a zero-mean subset of the enso index for each sample. - - The target and gen coefficients are scaled versions of 1.0 and -1.0 for - perfectly correlated and anti-correlated data, respectively. - - The global-mean coefficient RMSE is scaled from 2.0 (perfect + - The logged target and gen coefficients are scaled versions of 1.0 and + -1.0 for perfectly correlated and anti-correlated data,respectively. + - The logged global-mean coefficient RMSE is scaled from 2.0 (perfect correlation minus perfect anti-correlation). + - The above two are also true via the `get_dataset` method. Args: scaling: How much to scale up or down the target and generated data; @@ -95,28 +97,31 @@ def test_enso_coefficient_aggregator_values(scaling): """ n_samples, n_times, n_lat, n_lon = 2, 28, 3, 3 scale = 3 - area_weights = torch.ones([n_lat, n_lon]) + area_weights = torch.ones([n_lat, n_lon], device=get_device()) # get data that doesn't vary in space, but varies in time with the ENSO index - enso_index, sample_times, target_data, gen_data = _get_data( + enso_index, sample_time, target_data, gen_data = _get_data( scale, n_samples, n_times, n_lat, n_lon ) with change_aggregator_enso_index(EnsoCoefficientEvaluatorAggregator, enso_index): enso_agg = EnsoCoefficientEvaluatorAggregator( - initial_times=sample_times.isel(time=0), + initial_time=sample_time.isel(time=0), n_forward_timesteps=(n_times - 1), timestep=datetime.timedelta(hours=6), - area_weights=area_weights, + gridded_operations=LatLonOperations(area_weights), ) assert len(enso_agg._sample_index_series) == n_samples for index_values in enso_agg._sample_index_series: + assert index_values is not None # check that the index values are zero-mean for each sample assert np.isclose(index_values.mean().item(), 0.0) target_data["a"] *= scaling gen_data["a"] *= scaling - enso_agg.record_batch(time=sample_times, target_data=target_data, gen_data=gen_data) - enso_agg.record_batch(time=sample_times, target_data=target_data, gen_data=gen_data) + enso_agg.record_batch(time=sample_time, target_data=target_data, gen_data=gen_data) + enso_agg.record_batch(time=sample_time, target_data=target_data, gen_data=gen_data) coefficients = enso_agg._get_coefficients() target_coefficients, gen_coefficients = coefficients + assert target_coefficients is not None + assert gen_coefficients is not None # check that the target coefficients are 1.0 * scaling (perfectly correlated) assert torch.allclose(target_coefficients["a"], torch.tensor(scaling)) # check that the gen coefficients are -1.0 * scaling (perfectly anti-correlated) @@ -127,6 +132,16 @@ def test_enso_coefficient_aggregator_values(scaling): np.testing.assert_almost_equal( logs["enso_coefficients/rmse/a"], 2.0 * scaling, decimal=5 ) + enso_dataset = enso_agg.get_dataset() + # check that the coefficients are as expected in the dataset also + np.testing.assert_array_almost_equal( + enso_dataset["a"].sel(source="target").values, + target_coefficients["a"].cpu().numpy(), + ) + np.testing.assert_array_almost_equal( + 
enso_dataset["a"].sel(source="prediction").values, + gen_coefficients["a"].cpu().numpy(), + ) @pytest.mark.parametrize("shift", [1.5, 0.95, 0.05, 0.0]) @@ -138,27 +153,29 @@ def test_enso_index_inference_overlap(shift): n_samples, n_times, n_lat, n_lon = 2, 28, 3, 3 data_scale = 3 area_weights = torch.ones([n_lat, n_lon]) - enso_index, sample_times, target_data, gen_data = _get_data( + enso_index, sample_time, target_data, gen_data = _get_data( data_scale, n_samples, n_times, n_lat, n_lon ) # shift the sample times so they only partially overlap the reference index index_duration = enso_index.time[-1].item() - enso_index.time[0].item() offset_seconds = shift * index_duration.total_seconds() - sample_times += datetime.timedelta(seconds=offset_seconds) + sample_time += datetime.timedelta(seconds=offset_seconds) with change_aggregator_enso_index(EnsoCoefficientEvaluatorAggregator, enso_index): enso_agg = EnsoCoefficientEvaluatorAggregator( - initial_times=sample_times.isel(time=0), + initial_time=sample_time.isel(time=0), n_forward_timesteps=(n_times - 1), timestep=datetime.timedelta(hours=6), - area_weights=area_weights, + gridded_operations=LatLonOperations(area_weights), ) - enso_agg.record_batch(time=sample_times, target_data=target_data, gen_data=gen_data) + enso_agg.record_batch(time=sample_time, target_data=target_data, gen_data=gen_data) target_coefficients, gen_coefficients = enso_agg._get_coefficients() - overlap = 1.0 - shift + overlap = max(1.0 - shift, 0.0) if overlap < OVERLAP_THRESHOLD: # should be empty dict assert not target_coefficients assert not gen_coefficients else: + assert target_coefficients is not None + assert gen_coefficients is not None assert isinstance(target_coefficients["a"], torch.Tensor) assert isinstance(gen_coefficients["a"], torch.Tensor) @@ -180,17 +197,17 @@ def test_enso_agg_calendar(calendar): n_samples, n_times, n_lat, n_lon = 2, 28, 3, 3 data_scale = 3 area_weights = torch.ones([n_lat, n_lon]) - enso_index, sample_times, target_data, gen_data = _get_data( + enso_index, sample_time, target_data, gen_data = _get_data( data_scale, n_samples, n_times, n_lat, n_lon, calendar=calendar ) enso_index = enso_index.convert_calendar("julian", dim="time", use_cftime=True) with change_aggregator_enso_index(EnsoCoefficientEvaluatorAggregator, enso_index): enso_agg = EnsoCoefficientEvaluatorAggregator( - initial_times=sample_times.isel(time=0), + initial_time=sample_time.isel(time=0), n_forward_timesteps=(n_times - 1), timestep=datetime.timedelta(hours=6), - area_weights=area_weights, + gridded_operations=LatLonOperations(area_weights), ) - enso_agg.record_batch(time=sample_times, target_data=target_data, gen_data=gen_data) + enso_agg.record_batch(time=sample_time, target_data=target_data, gen_data=gen_data) target_coefficients, gen_coefficients = enso_agg._get_coefficients() assert (target_coefficients is not None) and (gen_coefficients is not None) diff --git a/fme/fme/core/aggregator/inference/histogram.py b/fme/fme/ace/aggregator/inference/histogram.py similarity index 97% rename from fme/fme/core/aggregator/inference/histogram.py rename to fme/fme/ace/aggregator/inference/histogram.py index c7d7354..5338bb3 100644 --- a/fme/fme/core/aggregator/inference/histogram.py +++ b/fme/fme/ace/aggregator/inference/histogram.py @@ -12,7 +12,6 @@ def __init__(self): @torch.no_grad() def record_batch( self, - loss: float, target_data: TensorMapping, gen_data: TensorMapping, target_data_norm: TensorMapping, diff --git a/fme/fme/ace/aggregator/inference/main.py 
b/fme/fme/ace/aggregator/inference/main.py new file mode 100644 index 0000000..83589f0 --- /dev/null +++ b/fme/fme/ace/aggregator/inference/main.py @@ -0,0 +1,719 @@ +import dataclasses +import datetime +import warnings +from typing import ( + Callable, + Dict, + Iterable, + List, + Mapping, + Optional, + Protocol, + Sequence, + Union, +) + +import torch +import xarray as xr + +from fme.ace.data_loading.batch_data import BatchData, PairedData, PrognosticState +from fme.core.coordinates import ( + HorizontalCoordinates, + HybridSigmaPressureCoordinate, + LatLonCoordinates, +) +from fme.core.dataset.data_typing import VariableMetadata +from fme.core.generics.aggregator import ( + InferenceAggregatorABC, + InferenceLog, + InferenceLogs, +) +from fme.core.gridded_ops import GriddedOperations +from fme.core.typing_ import TensorDict, TensorMapping +from fme.core.wandb import Table, WandB + +from ..one_step.reduced import MeanAggregator as OneStepMeanAggregator +from .annual import GlobalMeanAnnualAggregator +from .enso import EnsoCoefficientEvaluatorAggregator +from .histogram import HistogramAggregator +from .reduced import MeanAggregator, SingleTargetMeanAggregator +from .seasonal import SeasonalAggregator +from .spectrum import PairedSphericalPowerSpectrumAggregator +from .time_mean import TimeMeanAggregator, TimeMeanEvaluatorAggregator +from .video import VideoAggregator +from .zonal_mean import ZonalMeanAggregator + +wandb = WandB.get_instance() +APPROXIMATELY_TWO_YEARS = datetime.timedelta(days=730) +SLIGHTLY_LESS_THAN_FIVE_YEARS = datetime.timedelta(days=1800) + + +class _Aggregator(Protocol): + @torch.no_grad() + def record_batch( + self, + data: TensorMapping, + ): ... + + @torch.no_grad() + def get_logs(self, label: str): ... + + @torch.no_grad() + def get_dataset(self) -> xr.Dataset: ... + + +class _EvaluatorAggregator(Protocol): + @torch.no_grad() + def record_batch( + self, + target_data: TensorMapping, + gen_data: TensorMapping, + target_data_norm: TensorMapping, + gen_data_norm: TensorMapping, + i_time_start: int = 0, + ): ... + + @torch.no_grad() + def get_logs(self, label: str): ... + + @torch.no_grad() + def get_dataset(self) -> xr.Dataset: ... + + +class _TimeDependentAggregator(Protocol): + @torch.no_grad() + def record_batch( + self, + time: xr.DataArray, + data: TensorMapping, + ): ... + + @torch.no_grad() + def get_logs(self, label: str): ... + + @torch.no_grad() + def get_dataset(self) -> xr.Dataset: ... + + +class _TimeDependentEvaluatorAggregator(Protocol): + @torch.no_grad() + def record_batch( + self, + time: xr.DataArray, + target_data: TensorMapping, + gen_data: TensorMapping, + ): ... + + @torch.no_grad() + def get_logs(self, label: str): ... + + @torch.no_grad() + def get_dataset(self) -> xr.Dataset: ... + + +@dataclasses.dataclass +class InferenceEvaluatorAggregatorConfig: + """ + Configuration for inference evaluator aggregator. + + Parameters: + log_histograms: Whether to log histograms of the targets and predictions. + log_video: Whether to log videos of the state evolution. + log_extended_video: Whether to log wandb videos of the predictions with + statistical metrics, only done if log_video is True. + log_zonal_mean_images: Whether to log zonal-mean images (hovmollers) with a + time dimension. + log_seasonal_means: Whether to log seasonal mean metrics and images. + log_global_mean_time_series: Whether to log global mean time series metrics. + log_global_mean_norm_time_series: Whether to log the normalized global mean + time series metrics. 
+ monthly_reference_data: Path to monthly reference data to compare against. + time_mean_reference_data: Path to reference time means to compare against. + """ + + log_histograms: bool = False + log_video: bool = False + log_extended_video: bool = False + log_zonal_mean_images: bool = True + log_seasonal_means: bool = False + log_global_mean_time_series: bool = True + log_global_mean_norm_time_series: bool = True + monthly_reference_data: Optional[str] = None + time_mean_reference_data: Optional[str] = None + + def build( + self, + vertical_coordinate: HybridSigmaPressureCoordinate, + horizontal_coordinates: HorizontalCoordinates, + timestep: datetime.timedelta, + n_timesteps: int, + initial_time: xr.DataArray, + normalize: Callable[[TensorMapping], TensorDict], + record_step_20: bool = False, + variable_metadata: Optional[Mapping[str, VariableMetadata]] = None, + channel_mean_names: Optional[Sequence[str]] = None, + ) -> "InferenceEvaluatorAggregator": + if self.monthly_reference_data is None: + monthly_reference_data = None + else: + monthly_reference_data = xr.open_dataset(self.monthly_reference_data) + if self.time_mean_reference_data is None: + time_mean = None + else: + time_mean = xr.open_dataset(self.time_mean_reference_data) + + if n_timesteps > 2**15 and self.log_zonal_mean_images: + # matplotlib raises an error if image size is too large, and we plot + # one pixel per timestep in the zonal mean images. + warnings.warn( + "Disabling zonal mean images logging due to large number of timesteps" + f" (n_timesteps={n_timesteps}). Set log_zonal_mean_images=False or " + "decrease n_timesteps to below 2**15 to avoid this warning." + ) + log_zonal_mean_images = False + else: + log_zonal_mean_images = self.log_zonal_mean_images + + return InferenceEvaluatorAggregator( + vertical_coordinate=vertical_coordinate, + horizontal_coordinates=horizontal_coordinates, + timestep=timestep, + n_timesteps=n_timesteps, + initial_time=initial_time, + log_histograms=self.log_histograms, + log_video=self.log_video, + enable_extended_videos=self.log_extended_video, + log_zonal_mean_images=log_zonal_mean_images, + log_seasonal_means=self.log_seasonal_means, + log_global_mean_time_series=self.log_global_mean_time_series, + log_global_mean_norm_time_series=self.log_global_mean_norm_time_series, + monthly_reference_data=monthly_reference_data, + time_mean_reference_data=time_mean, + record_step_20=record_step_20, + variable_metadata=variable_metadata, + channel_mean_names=channel_mean_names, + normalize=normalize, + ) + + +class InferenceEvaluatorAggregator( + InferenceAggregatorABC[ + Union[PairedData, PrognosticState], + PairedData, + ] +): + """ + Aggregates statistics for inference comparing a generated and target series. + + To use, call `record_batch` on the results of each batch, then call + `get_logs` to get a dictionary of statistics when you're done. 
+ """ + + def __init__( + self, + vertical_coordinate: HybridSigmaPressureCoordinate, + horizontal_coordinates: HorizontalCoordinates, + timestep: datetime.timedelta, + n_timesteps: int, + initial_time: xr.DataArray, + normalize: Callable[[TensorMapping], TensorDict], + record_step_20: bool = False, + log_video: bool = False, + enable_extended_videos: bool = False, + log_zonal_mean_images: bool = False, + log_seasonal_means: bool = False, + log_global_mean_time_series: bool = True, + log_global_mean_norm_time_series: bool = True, + variable_metadata: Optional[Mapping[str, VariableMetadata]] = None, + monthly_reference_data: Optional[xr.Dataset] = None, + log_histograms: bool = False, + time_mean_reference_data: Optional[xr.Dataset] = None, + channel_mean_names: Optional[Sequence[str]] = None, + ): + """ + Args: + vertical_coordinate: Data vertical coordinate. + horizontal_coordinates: Data horizontal coordinates + timestep: Timestep of the model. + n_timesteps: Number of timesteps of inference that will be run. + initial_time: Initial time for each sample. + normalize: Normalization function to use. + record_step_20: Whether to record the mean of the 20th steps. + log_video: Whether to log videos of the state evolution. + enable_extended_videos: Whether to log videos of statistical + metrics of state evolution + log_zonal_mean_images: Whether to log zonal-mean images (hovmollers) with a + time dimension. + log_seasonal_means: Whether to log seasonal means metrics and images. + log_global_mean_time_series: Whether to log global mean time series metrics. + log_global_mean_norm_time_series: Whether to log the normalized global mean + time series metrics. + variable_metadata: Mapping of variable names their metadata that will + used in generating logged image captions. + monthly_reference_data: Reference monthly data for computing target stats. + log_histograms: Whether to aggregate histograms. + data_grid: The grid type of the data, used for spherical power spectrum. + time_mean_reference_data: Reference time means for computing bias stats. + channel_mean_names: Names over which to compute channel means. If not + provided, all available variables will be used. 
+ """ + self._channel_mean_names = channel_mean_names + self._aggregators: Dict[str, _EvaluatorAggregator] = {} + self._time_dependent_aggregators: Dict[ + str, _TimeDependentEvaluatorAggregator + ] = {} + ops = horizontal_coordinates.gridded_operations + self._log_time_series = ( + log_global_mean_time_series or log_global_mean_norm_time_series + ) + if log_global_mean_time_series: + self._aggregators["mean"] = MeanAggregator( + ops, + target="denorm", + n_timesteps=n_timesteps, + variable_metadata=variable_metadata, + ) + if log_global_mean_norm_time_series: + self._aggregators["mean_norm"] = MeanAggregator( + ops, + target="norm", + n_timesteps=n_timesteps, + variable_metadata=variable_metadata, + ) + if record_step_20: + self._aggregators["mean_step_20"] = OneStepMeanAggregator( + ops, target_time=20 + ) + if isinstance(horizontal_coordinates, LatLonCoordinates): + if log_zonal_mean_images: + self._aggregators["zonal_mean"] = ZonalMeanAggregator( + n_timesteps=n_timesteps, + variable_metadata=variable_metadata, + ) + self._aggregators["spherical_power_spectrum"] = ( + PairedSphericalPowerSpectrumAggregator( + horizontal_coordinates.area_weights.shape[-2], + horizontal_coordinates.area_weights.shape[-1], + horizontal_coordinates.grid, + ) + ) + if log_video: + self._aggregators["video"] = VideoAggregator( + n_timesteps=n_timesteps, + enable_extended_videos=enable_extended_videos, + variable_metadata=variable_metadata, + ) + self._aggregators["time_mean"] = TimeMeanEvaluatorAggregator( + ops, + horizontal_dims=horizontal_coordinates.dims, + variable_metadata=variable_metadata, + reference_means=time_mean_reference_data, + ) + self._aggregators["time_mean_norm"] = TimeMeanEvaluatorAggregator( + ops, + horizontal_dims=horizontal_coordinates.dims, + target="norm", + variable_metadata=variable_metadata, + ) + if log_histograms: + self._aggregators["histogram"] = HistogramAggregator() + if log_seasonal_means: + self._time_dependent_aggregators["seasonal"] = SeasonalAggregator( + ops=ops, + variable_metadata=variable_metadata, + ) + if n_timesteps * timestep > APPROXIMATELY_TWO_YEARS: + self._time_dependent_aggregators["annual"] = GlobalMeanAnnualAggregator( + ops=ops, + timestep=timestep, + variable_metadata=variable_metadata, + monthly_reference_data=monthly_reference_data, + ) + if n_timesteps * timestep > SLIGHTLY_LESS_THAN_FIVE_YEARS: + self._time_dependent_aggregators["enso_coefficient"] = ( + EnsoCoefficientEvaluatorAggregator( + initial_time, + n_timesteps - 1, + timestep, + gridded_operations=ops, + variable_metadata=variable_metadata, + ) + ) + self._summary_aggregators = { + name: agg + for name, agg in list(self._aggregators.items()) + + list(self._time_dependent_aggregators.items()) + if name not in ["mean", "mean_norm"] + } + self._n_timesteps_seen = 0 + self._normalize = normalize + + @property + def log_time_series(self) -> bool: + return self._log_time_series + + @torch.no_grad() + def record_batch( + self, + data: PairedData, + ) -> InferenceLogs: + if len(data.prediction) == 0: + raise ValueError("No prediction values in data") + if len(data.target) == 0: + raise ValueError("No target values in data") + target_data = {k: v for k, v in data.target.items() if k in data.prediction} + target_data_norm = self._normalize(target_data) + gen_data_norm = self._normalize(data.prediction) + for aggregator in self._aggregators.values(): + aggregator.record_batch( + target_data=target_data, + gen_data=data.prediction, + target_data_norm=target_data_norm, + gen_data_norm=gen_data_norm, + 
i_time_start=self._n_timesteps_seen, + ) + for time_dependent_aggregator in self._time_dependent_aggregators.values(): + time_dependent_aggregator.record_batch( + time=data.time, + target_data=target_data, + gen_data=data.prediction, + ) + n_times = data.time.shape[1] + logs = self._get_inference_logs_slice( + step_slice=slice(self._n_timesteps_seen, self._n_timesteps_seen + n_times), + ) + self._n_timesteps_seen += n_times + return logs + + def record_initial_condition( + self, + initial_condition: Union[PairedData, PrognosticState], + ) -> InferenceLogs: + if self._n_timesteps_seen != 0: + raise RuntimeError( + "record_initial_condition may only be called once, " + "before recording any batches" + ) + if isinstance(initial_condition, PairedData): + target_data = initial_condition.target + target_data_norm = self._normalize(target_data) + gen_data = initial_condition.prediction + gen_data_norm = self._normalize(gen_data) + n_times = initial_condition.time.shape[1] + else: + batch_data = initial_condition.as_batch_data() + target_data = batch_data.data + target_data_norm = self._normalize(target_data) + gen_data = target_data + gen_data_norm = target_data_norm + n_times = batch_data.time.shape[1] + for aggregator_name in ["mean", "mean_norm"]: + aggregator = self._aggregators.get(aggregator_name) + if aggregator is not None: + aggregator.record_batch( + target_data=target_data, + gen_data=gen_data, + target_data_norm=target_data_norm, + gen_data_norm=gen_data_norm, + i_time_start=0, + ) + logs = self._get_inference_logs_slice( + step_slice=slice(self._n_timesteps_seen, self._n_timesteps_seen + n_times), + ) + self._n_timesteps_seen = n_times + return logs + + def get_summary_logs(self) -> InferenceLog: + logs = {} + for name, aggregator in self._summary_aggregators.items(): + logs.update(aggregator.get_logs(label=name)) + return logs + + @torch.no_grad() + def _get_logs(self): + """ + Returns logs as can be reported to WandB. + """ + logs = {} + for name, aggregator in self._aggregators.items(): + logs.update(aggregator.get_logs(label=name)) + for name, time_dependent_aggregator in self._time_dependent_aggregators.items(): + logs.update(time_dependent_aggregator.get_logs(label=name)) + return logs + + @torch.no_grad() + def _get_inference_logs_slice(self, step_slice: slice): + """ + Returns a subset of the time series for applicable metrics + for a specific slice of as can be reported to WandB. + + Args: + step_slice: Timestep slice to determine the time series subset. + + Returns: + Tuple of start index and list of logs. + """ + logs = {} + for name, aggregator in self._aggregators.items(): + if isinstance(aggregator, MeanAggregator): + logs.update(aggregator.get_logs(label=name, step_slice=step_slice)) + return to_inference_logs(logs) + + @torch.no_grad() + def get_datasets( + self, excluded_aggregators: Optional[Iterable[str]] = None + ) -> Dict[str, xr.Dataset]: + """ + Returns datasets from combined aggregators. + + Args: + excluded_aggregators: aggregator names for which `get_dataset` + should not be called and no output should be returned. + + Returns: + Dictionary of datasets from aggregators. 
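Continuing the sketch above: get_datasets, documented here, returns one xarray.Dataset per aggregator, which can be written straight to disk. The output directory and the excluded aggregator name ("histogram") are illustrative assumptions.

import os

datasets = agg.get_datasets(excluded_aggregators=["histogram"])
out_dir = "/tmp/inference_diagnostics"  # hypothetical location
os.makedirs(out_dir, exist_ok=True)
for name, ds in datasets.items():
    ds.to_netcdf(os.path.join(out_dir, f"{name}_diagnostics.nc"))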
+ """ + if excluded_aggregators is None: + excluded_aggregators = [] + + combined_aggregators: Dict[ + str, Union[_Aggregator, _TimeDependentAggregator] + ] = { + **self._aggregators, + **self._time_dependent_aggregators, + } + return { + name: agg.get_dataset() + for name, agg in combined_aggregators.items() + if name not in excluded_aggregators + } + + +def to_inference_logs( + log: Mapping[str, Union[Table, float, int]], +) -> List[Dict[str, Union[float, int]]]: + # we have a dictionary which contains WandB tables + # which we will convert to a list of dictionaries, one for each + # row in the tables. Any scalar values will be reported in the last + # dictionary. + n_rows = 0 + for val in log.values(): + if isinstance(val, Table): + n_rows = max(n_rows, len(val.data)) + logs: List[Dict[str, Union[float, int]]] = [] + for i in range(max(1, n_rows)): # need at least one for non-series values + logs.append({}) + for key, val in log.items(): + if isinstance(val, Table): + for i, row in enumerate(val.data): + for j, col in enumerate(val.columns): + key_without_table_name = key[: key.rfind("/")] + logs[i][f"{key_without_table_name}/{col}"] = row[j] + else: + logs[-1][key] = val + return logs + + +def table_to_logs(table: Table) -> List[Dict[str, Union[float, int]]]: + """ + Converts a WandB table into a list of dictionaries. + """ + logs = [] + for row in table.data: + logs.append({table.columns[i]: row[i] for i in range(len(row))}) + return logs + + +@dataclasses.dataclass +class InferenceAggregatorConfig: + """ + Configuration for inference aggregator. + + Parameters: + time_mean_reference_data: Path to reference time means to compare against. + log_global_mean_time_series: Whether to log global mean time series metrics. + """ + + time_mean_reference_data: Optional[str] = None + log_global_mean_time_series: bool = True + + def build( + self, + gridded_operations: GriddedOperations, + n_timesteps: int, + variable_metadata: Optional[Mapping[str, VariableMetadata]] = None, + ) -> "InferenceAggregator": + if self.time_mean_reference_data is not None: + time_means = xr.open_dataset(self.time_mean_reference_data) + else: + time_means = None + return InferenceAggregator( + gridded_operations=gridded_operations, + n_timesteps=n_timesteps, + variable_metadata=variable_metadata, + time_mean_reference_data=time_means, + log_global_mean_time_series=self.log_global_mean_time_series, + ) + + +class InferenceAggregator( + InferenceAggregatorABC[ + PrognosticState, + BatchData, + ] +): + """ + Aggregates statistics on a single timeseries of data. + + To use, call `record_batch` on the results of each batch, then call + `get_logs` to get a dictionary of statistics when you're done. + """ + + def __init__( + self, + gridded_operations: GriddedOperations, + n_timesteps: int, + variable_metadata: Optional[Mapping[str, VariableMetadata]] = None, + time_mean_reference_data: Optional[xr.Dataset] = None, + log_global_mean_time_series: bool = True, + ): + """ + Args: + gridded_operations: Gridded operations for computing horizontal reductions. + n_timesteps: Number of timesteps in the model. + variable_metadata: Mapping of variable names their metadata that will + used in generating logged image captions. + time_mean_reference_data: Reference time means for computing bias stats. + log_global_mean_time_series: Whether to log global mean time series metrics. 
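A hand-rolled illustration of the row expansion performed by the to_inference_logs helper above, using plain lists in place of the fme.core.wandb.Table object; the metric names and values are made up.

columns = ["forecast_step", "weighted_rmse/a"]
rows = [[0, 0.0], [1, 1.3], [2, 1.7]]  # stand-in for Table.data
scalars = {"time_mean/rmse/a": 2.1}    # non-series values

# "mean/series" loses its "/series" suffix, each table row becomes one log dict,
# and scalar values are attached to the final dict:
per_step_logs = [
    {f"mean/{col}": row[j] for j, col in enumerate(columns)} for row in rows
]
per_step_logs[-1].update(scalars)
# [{'mean/forecast_step': 0, 'mean/weighted_rmse/a': 0.0},
#  {'mean/forecast_step': 1, 'mean/weighted_rmse/a': 1.3},
#  {'mean/forecast_step': 2, 'mean/weighted_rmse/a': 1.7, 'time_mean/rmse/a': 2.1}]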
+ """ + self._log_time_series = log_global_mean_time_series + aggregators: Dict[str, _Aggregator] = {} + if log_global_mean_time_series: + aggregators["mean"] = SingleTargetMeanAggregator( + gridded_operations, + n_timesteps=n_timesteps, + ) + aggregators["time_mean"] = TimeMeanAggregator( + gridded_operations=gridded_operations, + variable_metadata=variable_metadata, + reference_means=time_mean_reference_data, + ) + self._aggregators = aggregators + self._summary_aggregators = {"time_mean": aggregators["time_mean"]} + self._n_timesteps_seen = 0 + + @property + def log_time_series(self) -> bool: + return self._log_time_series + + @torch.no_grad() + def record_batch( + self, + data: BatchData, + ) -> InferenceLogs: + """ + Record a batch of data. + + Args: + data: Batch of data to record. + """ + if len(data.data) == 0: + raise ValueError("data is empty") + for aggregator in self._aggregators.values(): + aggregator.record_batch( + data=data.data, + i_time_start=self._n_timesteps_seen, + ) + n_times = data.time.shape[1] + logs = self._get_inference_logs_slice( + step_slice=slice(self._n_timesteps_seen, self._n_timesteps_seen + n_times), + ) + self._n_timesteps_seen += n_times + return logs + + def record_initial_condition( + self, + initial_condition: PrognosticState, + ) -> InferenceLogs: + if self._n_timesteps_seen != 0: + raise RuntimeError( + "record_initial_condition may only be called once, " + "before recording any batches" + ) + batch_data = initial_condition.as_batch_data() + if "mean" in self._aggregators: + self._aggregators["mean"].record_batch( + data=batch_data.data, + i_time_start=0, + ) + n_times = batch_data.time.shape[1] + logs = self._get_inference_logs_slice( + step_slice=slice(self._n_timesteps_seen, self._n_timesteps_seen + n_times), + ) + self._n_timesteps_seen = n_times + return logs + + def get_summary_logs(self) -> InferenceLog: + logs = {} + for name, aggregator in self._summary_aggregators.items(): + logs.update(aggregator.get_logs(label=name)) + return logs + + @torch.no_grad() + def _get_logs(self): + """ + Returns logs as can be reported to WandB. + """ + logs = {} + for name, aggregator in self._aggregators.items(): + logs.update(aggregator.get_logs(label=name)) + return logs + + @torch.no_grad() + def _get_inference_logs(self) -> List[Dict[str, Union[float, int]]]: + """ + Returns a list of logs to report to WandB. + + This is done because in inference, we use the wandb step + as the time step, meaning we need to re-organize the logged data + from tables into a list of dictionaries. + """ + return to_inference_logs(self._get_logs()) + + @torch.no_grad() + def _get_inference_logs_slice(self, step_slice: slice): + """ + Returns a subset of the time series for applicable metrics + for a specific slice of as can be reported to WandB. + + Args: + step_slice: Timestep slice to determine the time series subset. + """ + logs = {} + for name, aggregator in self._aggregators.items(): + if isinstance(aggregator, SingleTargetMeanAggregator): + logs.update(aggregator.get_logs(label=name, step_slice=step_slice)) + return to_inference_logs(logs) + + @torch.no_grad() + def get_datasets( + self, excluded_aggregators: Optional[Iterable[str]] = None + ) -> Dict[str, xr.Dataset]: + """ + Returns datasets from combined aggregators. + + Args: + excluded_aggregators: aggregator names for which `get_dataset` + should not be called and no output should be returned. + + Returns: + Dictionary of datasets from aggregators. 
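A minimal usage sketch of the single-timeseries InferenceAggregator above, based on the test_inference.py file added later in this diff; sizes and the variable name "a" are illustrative assumptions.

import numpy as np
import torch
import xarray as xr

from fme.ace.aggregator.inference import InferenceAggregator
from fme.ace.data_loading.batch_data import BatchData
from fme.core.gridded_ops import LatLonOperations

n_sample, n_time, nlat, nlon = 2, 6, 4, 8
agg = InferenceAggregator(
    LatLonOperations(torch.ones(nlat, nlon)),
    n_timesteps=n_time,
)
time = xr.DataArray(
    np.zeros((n_sample, n_time), dtype="datetime64[ns]"), dims=["sample", "time"]
)
step_logs = agg.record_batch(
    BatchData(data={"a": torch.randn(n_sample, n_time, nlat, nlon)}, time=time)
)  # per-step global means and stds
summary_logs = agg.get_summary_logs()  # e.g. time_mean/gen_map/a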
+ """ + if excluded_aggregators is None: + excluded_aggregators = [] + + return { + name: agg.get_dataset() + for name, agg in self._aggregators.items() + if name not in excluded_aggregators + } diff --git a/fme/fme/core/aggregator/inference/reduced.py b/fme/fme/ace/aggregator/inference/reduced.py similarity index 68% rename from fme/fme/core/aggregator/inference/reduced.py rename to fme/fme/ace/aggregator/inference/reduced.py index 6eceda5..61b5680 100644 --- a/fme/fme/core/aggregator/inference/reduced.py +++ b/fme/fme/ace/aggregator/inference/reduced.py @@ -1,15 +1,15 @@ import dataclasses +from collections import defaultdict from typing import Dict, List, Literal, Mapping, Optional, Protocol import numpy as np import torch import xarray as xr -from fme.core import metrics -from fme.core.data_loading.data_typing import VariableMetadata +from fme.core.dataset.data_typing import VariableMetadata from fme.core.device import get_device from fme.core.distributed import Distributed -from fme.core.metrics import Dimension +from fme.core.gridded_ops import GriddedOperations from fme.core.typing_ import TensorMapping from fme.core.wandb import Table, WandB @@ -70,10 +70,7 @@ def __call__( self, truth: torch.Tensor, predicted: torch.Tensor, - weights: Optional[torch.Tensor] = None, - dim: Dimension = (), - ) -> torch.Tensor: - ... + ) -> torch.Tensor: ... class AreaWeightedSingleTargetFunction(Protocol): @@ -84,10 +81,7 @@ class AreaWeightedSingleTargetFunction(Protocol): def __call__( self, tensor: torch.Tensor, - weights: Optional[torch.Tensor] = None, - dim: Dimension = (), - ) -> torch.Tensor: - ... + ) -> torch.Tensor: ... def compute_metric_on( @@ -102,13 +96,11 @@ def compute_metric_on( def metric_wrapper( truth: torch.Tensor, predicted: torch.Tensor, - weights: Optional[torch.Tensor] = None, - dim: Dimension = (), ) -> torch.Tensor: if source == "gen": - return metric(predicted, weights=weights, dim=dim) + return metric(predicted) elif source == "target": - return metric(truth, weights=weights, dim=dim) + return metric(truth) return metric_wrapper @@ -120,12 +112,10 @@ class AreaWeightedReducedMetric: def __init__( self, - area_weights: torch.Tensor, device: torch.device, compute_metric: AreaWeightedFunction, n_timesteps: int, ): - self._area_weights = area_weights self._compute_metric = compute_metric self._total: Optional[torch.Tensor] = None self._n_batches = torch.zeros( @@ -143,17 +133,19 @@ def record(self, target: torch.Tensor, gen: torch.Tensor, i_time_start: int): i_time_start: The index of the first timestep in the batch. 
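The protocols above drop the weights/dim arguments because area weighting now lives inside a GriddedOperations instance. Below is a sketch of a custom metric satisfying the new two-argument protocol and being fed to AreaWeightedReducedMetric; the MAE metric, tensor shapes, and CPU device are assumptions for illustration only.

import torch

from fme.ace.aggregator.inference.reduced import (
    AreaWeightedReducedMetric,
    compute_metric_on,
)
from fme.core.gridded_ops import LatLonOperations

ops = LatLonOperations(torch.ones(4, 8))  # uniform area weights, illustrative

def area_weighted_mae(truth: torch.Tensor, predicted: torch.Tensor) -> torch.Tensor:
    # No weights/dim parameters: the GriddedOperations instance holds the weights.
    return ops.area_weighted_mean((predicted - truth).abs())

metric = AreaWeightedReducedMetric(
    device=torch.device("cpu"),
    compute_metric=area_weighted_mae,
    n_timesteps=6,
)
metric.record(
    target=torch.randn(2, 3, 4, 8),
    gen=torch.randn(2, 3, 4, 8),
    i_time_start=0,
)
series = metric.get()  # per-step mean over the batches recorded so far

# compute_metric_on adapts a single-tensor reduction to the (truth, predicted) API:
gen_mean = compute_metric_on(source="gen", metric=ops.area_weighted_mean)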
""" time_dim = 1 - if target.shape[time_dim] >= gen.shape[time_dim]: - new_value = self._compute_metric( - target, gen, weights=self._area_weights, dim=(-2, -1) - ).mean(dim=0) - if self._total is None: - self._total = torch.zeros( - [self._n_timesteps], dtype=new_value.dtype, device=self._device - ) - time_slice = slice(i_time_start, i_time_start + gen.shape[1]) - self._total[time_slice] += new_value - self._n_batches[time_slice] += 1 + if target.shape != gen.shape: + raise RuntimeError( + "target and gen must have the same shape, got " + f"{target.shape} and {gen.shape}" + ) + new_value = self._compute_metric(truth=target, predicted=gen).mean(dim=0) + if self._total is None: + self._total = torch.zeros( + [self._n_timesteps], dtype=new_value.dtype, device=self._device + ) + time_slice = slice(i_time_start, i_time_start + gen.shape[time_dim]) + self._total[time_slice] += new_value + self._n_batches[time_slice] += 1 def get(self) -> torch.Tensor: """Returns the mean metric across recorded batches.""" @@ -165,12 +157,12 @@ def get(self) -> torch.Tensor: class MeanAggregator: def __init__( self, - area_weights: torch.Tensor, + gridded_operations: GriddedOperations, target: Literal["norm", "denorm"], n_timesteps: int, - metadata: Optional[Mapping[str, VariableMetadata]] = None, + variable_metadata: Optional[Mapping[str, VariableMetadata]] = None, ): - self._area_weights = area_weights + self._gridded_operations = gridded_operations self._variable_metrics: Optional[Dict[str, Dict[str, MeanMetric]]] = None self._shape_x = None self._shape_y = None @@ -178,89 +170,81 @@ def __init__( self._n_timesteps = n_timesteps self._dist = Distributed.get_instance() - if metadata is None: - self._metadata: Mapping[str, VariableMetadata] = {} + if variable_metadata is None: + self._variable_metadata: Mapping[str, VariableMetadata] = {} else: - self._metadata = metadata + self._variable_metadata = variable_metadata def _get_variable_metrics(self, gen_data: TensorMapping): if self._variable_metrics is None: - self._variable_metrics = { - "weighted_rmse": {}, - "weighted_mean_gen": {}, - "weighted_mean_target": {}, - "weighted_bias": {}, - "weighted_std_gen": {}, - } - - if self._target == "denorm": - # redundant for the "norm" case - self._variable_metrics["weighted_grad_mag_percent_diff"] = {} + self._variable_metrics = {} device = get_device() - for key in gen_data: - self._variable_metrics["weighted_rmse"][ - key - ] = AreaWeightedReducedMetric( - area_weights=self._area_weights, + self._variable_metrics["weighted_rmse"] = defaultdict( + lambda: AreaWeightedReducedMetric( device=device, - compute_metric=metrics.root_mean_squared_error, + compute_metric=self._gridded_operations.area_weighted_rmse, n_timesteps=self._n_timesteps, ) - if self._target == "denorm": - self._variable_metrics["weighted_grad_mag_percent_diff"][ - key - ] = AreaWeightedReducedMetric( - area_weights=self._area_weights, + ) + if self._target == "denorm": + self._variable_metrics["weighted_grad_mag_percent_diff"] = defaultdict( + lambda: AreaWeightedReducedMetric( device=device, - compute_metric=metrics.gradient_magnitude_percent_diff, + compute_metric=self._gridded_operations.area_weighted_gradient_magnitude_percent_diff, # noqa: E501 n_timesteps=self._n_timesteps, ) - self._variable_metrics["weighted_mean_gen"][ - key - ] = AreaWeightedReducedMetric( - area_weights=self._area_weights, + ) + self._variable_metrics["weighted_mean_gen"] = defaultdict( + lambda: AreaWeightedReducedMetric( device=device, compute_metric=compute_metric_on( - 
source="gen", metric=metrics.weighted_mean + source="gen", + metric=lambda tensor: ( + self._gridded_operations.area_weighted_mean(tensor) + ), ), n_timesteps=self._n_timesteps, ) - self._variable_metrics["weighted_mean_target"][ - key - ] = AreaWeightedReducedMetric( - area_weights=self._area_weights, + ) + self._variable_metrics["weighted_mean_target"] = defaultdict( + lambda: AreaWeightedReducedMetric( device=device, compute_metric=compute_metric_on( - source="target", metric=metrics.weighted_mean + source="target", + metric=lambda tensor: ( + self._gridded_operations.area_weighted_mean(tensor) + ), ), n_timesteps=self._n_timesteps, ) - self._variable_metrics["weighted_bias"][ - key - ] = AreaWeightedReducedMetric( - area_weights=self._area_weights, + ) + self._variable_metrics["weighted_bias"] = defaultdict( + lambda: AreaWeightedReducedMetric( device=device, - compute_metric=metrics.weighted_mean_bias, + compute_metric=self._gridded_operations.area_weighted_mean_bias, n_timesteps=self._n_timesteps, ) - self._variable_metrics["weighted_std_gen"][ - key - ] = AreaWeightedReducedMetric( - area_weights=self._area_weights, + ) + self._variable_metrics["weighted_std_gen"] = defaultdict( + lambda: AreaWeightedReducedMetric( device=device, compute_metric=compute_metric_on( - source="gen", metric=metrics.weighted_std + source="gen", + metric=( + lambda tensor: self._gridded_operations.area_weighted_std( + tensor + ) + ), ), n_timesteps=self._n_timesteps, ) - + ) return self._variable_metrics @torch.no_grad() def record_batch( self, - loss: float, target_data: TensorMapping, gen_data: TensorMapping, target_data_norm: TensorMapping, @@ -279,7 +263,7 @@ def record_batch( i_time_start=i_time_start, ) - def _get_series_data(self) -> List[_SeriesData]: + def _get_series_data(self, step_slice: Optional[slice] = None) -> List[_SeriesData]: """Converts internally stored variable_metrics to a list.""" if self._variable_metrics is None: raise ValueError("No batches have been recorded.") @@ -288,6 +272,8 @@ def _get_series_data(self) -> List[_SeriesData]: sorted_keys = sorted(list(self._variable_metrics[metric].keys())) for key in sorted_keys: arr = self._variable_metrics[metric][key].get().detach() + if step_slice is not None: + arr = arr[step_slice] datum = _SeriesData( metric_name=metric, var_name=key, @@ -297,18 +283,21 @@ def _get_series_data(self) -> List[_SeriesData]: return data @torch.no_grad() - def get_logs(self, label: str): + def get_logs(self, label: str, step_slice: Optional[slice] = None): """ Returns logs as can be reported to WandB. Args: label: Label to prepend to all log keys. + step_slice: Slice of forecast steps to log. 
""" logs = {} series_data: Dict[str, np.ndarray] = { - datum.get_wandb_key(): datum.data for datum in self._get_series_data() + datum.get_wandb_key(): datum.data + for datum in self._get_series_data(step_slice) } - table = data_to_table(series_data) + init_step = 0 if step_slice is None else step_slice.start + table = data_to_table(series_data, init_step) logs[f"{label}/series"] = table return logs @@ -319,30 +308,39 @@ def get_dataset(self) -> xr.Dataset: """ data_vars = {} for datum in self._get_series_data(): - metadata = self._metadata.get( + metadata = self._variable_metadata.get( datum.var_name, VariableMetadata("unknown_units", datum.var_name) ) data_vars[datum.get_xarray_key()] = xr.DataArray( datum.data, dims=["forecast_step"], attrs=metadata._asdict() ) - n_forecast_steps = len(next(iter(data_vars.values()))) - coords = {"forecast_step": np.arange(n_forecast_steps)} + if len(data_vars.values()) > 0: + n_forecast_steps = len(next(iter(data_vars.values()))) + coords = {"forecast_step": np.arange(n_forecast_steps)} + else: + coords = {"forecast_step": np.arange(0)} + return xr.Dataset(data_vars=data_vars, coords=coords) -def data_to_table(data: Dict[str, np.ndarray]) -> Table: +def data_to_table(data: Dict[str, np.ndarray], init_step: int = 0) -> Table: """ Convert a dictionary of 1-dimensional timeseries data to a wandb Table. + + Args: + data: dictionary of timeseries data. + init_step: initial step corresponding to the first row's "forecast_step" """ keys = sorted(list(data.keys())) wandb = WandB.get_instance() table = wandb.Table(columns=["forecast_step"] + keys) - for i in range(len(data[keys[0]])): - row = [i] - for key in keys: - row.append(data[key][i]) - table.add_data(*row) + if len(keys) > 0: + for i in range(len(data[keys[0]])): + row = [init_step + i] + for key in keys: + row.append(data[key][i]) + table.add_data(*row) return table @@ -353,12 +351,10 @@ class AreaWeightedSingleTargetReducedMetric: def __init__( self, - area_weights: torch.Tensor, device: torch.device, compute_metric: AreaWeightedSingleTargetFunction, n_timesteps: int, ): - self._area_weights = area_weights self._compute_metric = compute_metric self._total: Optional[torch.Tensor] = None self._n_batches = torch.zeros( @@ -374,9 +370,7 @@ def record(self, tensor: torch.Tensor, i_time_start: int): tensor: batch data. Should have shape [batch, time, height, width]. i_time_start: The index of the first timestep in the batch. 
""" - new_value = self._compute_metric( - tensor, weights=self._area_weights, dim=(-2, -1) - ).mean(dim=0) + new_value = self._compute_metric(tensor).mean(dim=0) if self._total is None: self._total = torch.zeros( [self._n_timesteps], dtype=new_value.dtype, device=self._device @@ -395,11 +389,11 @@ def get(self) -> torch.Tensor: class SingleTargetMeanAggregator: def __init__( self, - area_weights: torch.Tensor, + gridded_operations: GriddedOperations, n_timesteps: int, - metadata: Optional[Mapping[str, VariableMetadata]] = None, + variable_metadata: Optional[Mapping[str, VariableMetadata]] = None, ): - self._area_weights = area_weights + self._ops = gridded_operations self._variable_metrics: Optional[ Dict[str, Dict[str, SingleTargetMeanMetric]] ] = None @@ -408,10 +402,10 @@ def __init__( self._n_timesteps = n_timesteps self._dist = Distributed.get_instance() - if metadata is None: - self._metadata: Mapping[str, VariableMetadata] = {} + if variable_metadata is None: + self._variable_metadata: Mapping[str, VariableMetadata] = {} else: - self._metadata = metadata + self._variable_metadata = variable_metadata def _get_variable_metrics(self, gen_data: TensorMapping): if self._variable_metrics is None: @@ -421,23 +415,20 @@ def _get_variable_metrics(self, gen_data: TensorMapping): } device = get_device() - for key in gen_data: - self._variable_metrics["weighted_mean_gen"][ - key - ] = AreaWeightedSingleTargetReducedMetric( - area_weights=self._area_weights, + self._variable_metrics["weighted_mean_gen"] = defaultdict( + lambda: AreaWeightedSingleTargetReducedMetric( device=device, - compute_metric=metrics.weighted_mean, + compute_metric=lambda x: self._ops.area_weighted_mean(x), n_timesteps=self._n_timesteps, ) - self._variable_metrics["weighted_std_gen"][ - key - ] = AreaWeightedSingleTargetReducedMetric( - area_weights=self._area_weights, + ) + self._variable_metrics["weighted_std_gen"] = defaultdict( + lambda: AreaWeightedSingleTargetReducedMetric( device=device, - compute_metric=metrics.weighted_std, + compute_metric=lambda x: self._ops.area_weighted_std(x), n_timesteps=self._n_timesteps, ) + ) return self._variable_metrics @@ -455,7 +446,7 @@ def record_batch( i_time_start=i_time_start, ) - def _get_series_data(self) -> List[_SeriesData]: + def _get_series_data(self, step_slice: Optional[slice] = None) -> List[_SeriesData]: """Converts internally stored variable_metrics to a list.""" if self._variable_metrics is None: raise ValueError("No batches have been recorded.") @@ -464,6 +455,8 @@ def _get_series_data(self) -> List[_SeriesData]: sorted_keys = sorted(list(self._variable_metrics[metric].keys())) for key in sorted_keys: arr = self._variable_metrics[metric][key].get().detach() + if step_slice is not None: + arr = arr[step_slice] datum = _SeriesData( metric_name=metric, var_name=key, @@ -473,18 +466,21 @@ def _get_series_data(self) -> List[_SeriesData]: return data @torch.no_grad() - def get_logs(self, label: str): + def get_logs(self, label: str, step_slice: Optional[slice] = None): """ Returns logs as can be reported to WandB. Args: label: Label to prepend to all log keys. + step_slice: Slice of forecast steps to log. 
""" logs = {} series_data: Dict[str, np.ndarray] = { - datum.get_wandb_key(): datum.data for datum in self._get_series_data() + datum.get_wandb_key(): datum.data + for datum in self._get_series_data(step_slice) } - table = data_to_table(series_data) + init_step = 0 if step_slice is None else step_slice.start + table = data_to_table(series_data, init_step) logs[f"{label}/series"] = table return logs @@ -495,7 +491,7 @@ def get_dataset(self) -> xr.Dataset: """ data_vars = {} for datum in self._get_series_data(): - metadata = self._metadata.get( + metadata = self._variable_metadata.get( datum.var_name, VariableMetadata("unknown_units", datum.var_name) ) data_vars[datum.get_xarray_key()] = xr.DataArray( diff --git a/fme/fme/core/aggregator/inference/seasonal.py b/fme/fme/ace/aggregator/inference/seasonal.py similarity index 87% rename from fme/fme/core/aggregator/inference/seasonal.py rename to fme/fme/ace/aggregator/inference/seasonal.py index 6385e8f..522a082 100644 --- a/fme/fme/core/aggregator/inference/seasonal.py +++ b/fme/fme/ace/aggregator/inference/seasonal.py @@ -1,25 +1,27 @@ -from typing import Any, Dict, Mapping, Optional, cast +import logging +from typing import Any, Dict, Mapping, Optional, Union, cast import numpy as np import torch import xarray as xr -from fme.core import metrics -from fme.core.aggregator.plotting import plot_paneled_data -from fme.core.data_loading.data_typing import VariableMetadata +from fme.ace.aggregator.plotting import plot_paneled_data +from fme.core.dataset.data_typing import VariableMetadata from fme.core.device import get_device from fme.core.distributed import Distributed +from fme.core.gridded_ops import GriddedOperations from fme.core.typing_ import TensorMapping +from fme.core.wandb import Image class SeasonalAggregator: def __init__( self, - area_weights: torch.Tensor, - metadata: Optional[Mapping[str, VariableMetadata]] = None, + ops: GriddedOperations, + variable_metadata: Optional[Mapping[str, VariableMetadata]] = None, ): - self.area_weights = area_weights - self._metadata = metadata + self._area_weighted_mean = ops.area_weighted_mean + self._variable_metadata = variable_metadata self._target_dataset: Optional[xr.Dataset] = None self._gen_dataset: Optional[xr.Dataset] = None @@ -84,16 +86,16 @@ def get_logs(self, label: str) -> Dict[str, Any]: target = cast(xr.Dataset, target / target["counts"]) # type: ignore gen = cast(xr.Dataset, gen / gen["counts"]) # type: ignore bias = gen - target - plots = {} - metric_logs = {} + plots: Dict[str, Image] = {} + metric_logs: Dict[str, float] = {} for name in gen.data_vars.keys(): if name == "counts": continue - if self._metadata is not None and name in self._metadata: - long_name = self._metadata[name].long_name - units = self._metadata[name].units + if self._variable_metadata is not None and name in self._variable_metadata: + long_name = self._variable_metadata[name].long_name + units = self._variable_metadata[name].units caption_name = f"{long_name} ({units})" else: caption_name = name @@ -147,10 +149,8 @@ def get_logs(self, label: str) -> Dict[str, Any]: ) plots[f"bias/{name}"] = image_err - mse_tensor = metrics.weighted_mean( + mse_tensor = self._area_weighted_mean( torch.as_tensor(bias[name].values ** 2), - weights=self.area_weights.cpu(), - dim=(-2, -1), ) for i, season in enumerate(bias[name].season.values): rmse = float(mse_tensor[i].sqrt().numpy()) @@ -158,19 +158,24 @@ def get_logs(self, label: str) -> Dict[str, Any]: rmse = float( # must compute area mean and then mean across seasons # 
before sqrt, so we can't use metrics.root_mean_squared_error - mse_tensor.mean() - .sqrt() - .numpy() + mse_tensor.mean().sqrt().numpy() ) metric_logs[f"time-mean-rmse/{name}"] = rmse if len(label) > 0: label = label + "/" - logs = {} + logs: Dict[str, Union[Image, float]] = {} logs.update({f"{label}{name}": plots[name] for name in plots.keys()}) logs.update({f"{label}{name}": val for name, val in metric_logs.items()}) return logs + def get_dataset(self) -> xr.Dataset: + logging.debug( + "get_dataset not implemented for SeasonalAggregator. " + "Returning an empty dataset." + ) + return xr.Dataset() + ALL_SEASONS = np.asarray(["DJF", "MAM", "JJA", "SON"]) diff --git a/fme/fme/core/aggregator/inference/spectrum.py b/fme/fme/ace/aggregator/inference/spectrum.py similarity index 98% rename from fme/fme/core/aggregator/inference/spectrum.py rename to fme/fme/ace/aggregator/inference/spectrum.py index 66334ea..00cc77f 100644 --- a/fme/fme/core/aggregator/inference/spectrum.py +++ b/fme/fme/ace/aggregator/inference/spectrum.py @@ -1,3 +1,4 @@ +import logging import warnings from collections import defaultdict from typing import Dict, Optional @@ -52,7 +53,7 @@ def get_mean(self) -> Dict[str, torch.Tensor]: class PairedSphericalPowerSpectrumAggregator: - """Record batches and return plots for paired prediction and target data""" + """Record batches and return plots for paired prediction and target data.""" def __init__(self, nlat: int, nlon: int, grid: str = "legendre-gauss"): self._gen_aggregator = SphericalPowerSpectrumAggregator(nlat, nlon, grid) @@ -61,7 +62,6 @@ def __init__(self, nlat: int, nlon: int, grid: str = "legendre-gauss"): @torch.no_grad() def record_batch( self, - loss: float, target_data: TensorMapping, gen_data: TensorMapping, target_data_norm: TensorMapping, @@ -91,7 +91,7 @@ def get_logs(self, label: str) -> Dict[str, plt.Figure]: @torch.no_grad() def get_dataset(self) -> xr.Dataset: - warnings.warn( + logging.debug( "get_dataset not implemented for PairedSphericalPowerSpectrumAggregator. " "Returning an empty dataset." 
) diff --git a/fme/fme/core/aggregator/inference/test_annual.py b/fme/fme/ace/aggregator/inference/test_annual.py similarity index 76% rename from fme/fme/core/aggregator/inference/test_annual.py rename to fme/fme/ace/aggregator/inference/test_annual.py index 3a653fe..92a4b49 100644 --- a/fme/fme/core/aggregator/inference/test_annual.py +++ b/fme/fme/ace/aggregator/inference/test_annual.py @@ -8,9 +8,12 @@ import xarray as xr import fme -from fme.core.aggregator.inference.annual import GlobalMeanAnnualAggregator +from fme.ace.aggregator.inference.annual import GlobalMeanAnnualAggregator +from fme.ace.testing import DimSizes, MonthlyReferenceData +from fme.core.coordinates import DimSize from fme.core.device import get_device -from fme.core.testing import DimSizes, MonthlyReferenceData, mock_distributed +from fme.core.gridded_ops import LatLonOperations +from fme.core.testing import mock_distributed TIMESTEP = datetime.timedelta(hours=6) @@ -23,20 +26,20 @@ def test_annual_aggregator(tmpdir): n_time = 365 * 4 * 2 area_weights = torch.ones(n_lat, n_lon).to(fme.get_device()) names = ["a"] + horizontal = [DimSize("grid_yt", n_lat), DimSize("grid_xt", n_lon)] monthly_reference_data = MonthlyReferenceData( path=pathlib.Path(tmpdir), names=names, dim_sizes=DimSizes( n_time=48, - n_lat=n_lat, - n_lon=n_lon, + horizontal=horizontal, nz_interface=1, ), n_ensemble=3, ) monthly_ds = xr.open_dataset(monthly_reference_data.data_filename) agg = GlobalMeanAnnualAggregator( - area_weights=area_weights, + ops=LatLonOperations(area_weights), timestep=TIMESTEP, monthly_reference_data=monthly_ds, ) @@ -77,7 +80,7 @@ def test__get_gathered_means(use_mock_distributed): n_time = 365 * 4 * 2 # two years, approximately area_weights = torch.ones(n_lat, n_lon).to(fme.get_device()) agg = GlobalMeanAnnualAggregator( - area_weights=area_weights, + ops=LatLonOperations(area_weights), timestep=TIMESTEP, ) target_data = { @@ -103,11 +106,18 @@ def test__get_gathered_means(use_mock_distributed): if use_mock_distributed: world_size = 2 with mock_distributed(world_size=world_size): - target, gen = agg._get_gathered_means() + result = agg._get_gathered_means() + assert result is not None + target, gen = result + combined = agg.get_dataset() else: world_size = 1 - target, gen = agg._get_gathered_means() - for dataset in (target, gen): - assert set(dataset.dims) == {"sample", "year"} + result = agg._get_gathered_means() + assert result is not None + target, gen = result + combined = agg.get_dataset() + for dataset in (target, gen, combined): + assert set(dataset.dims).issuperset({"sample", "year"}) assert list(dataset.year.values) == [2000, 2001, 2002] assert dataset.sizes["sample"] == n_sample * world_size + assert set(combined.coords["source"].values) == set(["target", "prediction"]) diff --git a/fme/fme/core/aggregator/inference/test_distributed.py b/fme/fme/ace/aggregator/inference/test_distributed.py similarity index 77% rename from fme/fme/core/aggregator/inference/test_distributed.py rename to fme/fme/ace/aggregator/inference/test_distributed.py index d0985a5..fce2e74 100644 --- a/fme/fme/core/aggregator/inference/test_distributed.py +++ b/fme/fme/ace/aggregator/inference/test_distributed.py @@ -1,8 +1,9 @@ import torch -from fme.core.aggregator.inference.reduced import MeanAggregator -from fme.core.aggregator.inference.time_mean import TimeMeanEvaluatorAggregator +from fme.ace.aggregator.inference.reduced import MeanAggregator +from fme.ace.aggregator.inference.time_mean import TimeMeanEvaluatorAggregator from 
fme.core.device import get_device +from fme.core.gridded_ops import LatLonOperations from fme.core.testing import mock_distributed @@ -15,9 +16,11 @@ def test_mean_metrics_call_distributed(): with mock_distributed(-1.0): data_a = torch.ones([2, 3, 4, 4], device=get_device()) area_weights = torch.ones(1).to(get_device()) - agg = MeanAggregator(area_weights, target="denorm", n_timesteps=3) + agg = MeanAggregator( + LatLonOperations(area_weights), target="denorm", n_timesteps=3 + ) sample_data = {"a": data_a} - agg.record_batch(1.0, sample_data, sample_data, sample_data, sample_data) + agg.record_batch(sample_data, sample_data, sample_data, sample_data) logs = agg.get_logs(label="metrics") table = logs["metrics/series"] # assert all data past the first column in the WandB table is -1 @@ -33,11 +36,12 @@ def test_time_mean_metrics_call_distributed(): torch.manual_seed(0) with mock_distributed(0.0) as mock: area_weights = torch.ones(1).to(get_device()) - agg = TimeMeanEvaluatorAggregator(area_weights) + agg = TimeMeanEvaluatorAggregator( + LatLonOperations(area_weights), horizontal_dims=["lat", "lon"] + ) target_data = {"a": torch.ones([2, 3, 4, 4], device=get_device())} gen_data = {"a": torch.randn([2, 3, 4, 4], device=get_device())} agg.record_batch( - loss=1.0, target_data=target_data, gen_data=gen_data, target_data_norm=target_data, diff --git a/fme/fme/ace/aggregator/inference/test_evaluator.py b/fme/fme/ace/aggregator/inference/test_evaluator.py new file mode 100644 index 0000000..de75db4 --- /dev/null +++ b/fme/fme/ace/aggregator/inference/test_evaluator.py @@ -0,0 +1,210 @@ +import datetime + +import numpy as np +import pytest +import torch +import xarray as xr + +from fme.ace.aggregator.inference import InferenceEvaluatorAggregator +from fme.ace.data_loading.batch_data import BatchData, PairedData +from fme.core.coordinates import HybridSigmaPressureCoordinate, LatLonCoordinates +from fme.core.device import get_device + +TIMESTEP = datetime.timedelta(hours=6) + + +def get_zero_time(shape, dims): + return xr.DataArray(np.zeros(shape, dtype="datetime64[ns]"), dims=dims) + + +def test_logs_labels_exist(): + n_sample = 10 + n_time = 22 + nx = 2 + ny = 2 + nz = 3 + vertical_coordinate = HybridSigmaPressureCoordinate( + torch.arange(nz + 1), torch.arange(nz + 1) + ) + horizontal_coordinates = LatLonCoordinates( + lon=torch.arange(nx), + lat=torch.arange(ny), + loaded_lon_name="lon", + loaded_lat_name="lat", + ) + initial_time = get_zero_time(shape=[n_sample, 0], dims=["sample", "time"]) + + agg = InferenceEvaluatorAggregator( + vertical_coordinate=vertical_coordinate, + horizontal_coordinates=horizontal_coordinates, + timestep=TIMESTEP, + n_timesteps=n_time, + initial_time=initial_time, + record_step_20=True, + log_video=True, + log_zonal_mean_images=True, + normalize=lambda x: dict(x), + ) + time = xr.DataArray(np.zeros((n_sample, n_time)), dims=["sample", "time"]) + + logs = agg.record_batch( + data=PairedData( + prediction={ + "a": torch.randn(n_sample, n_time, nx, ny, device=get_device()) + }, + target={"a": torch.randn(n_sample, n_time, nx, ny, device=get_device())}, + time=time, + ), + ) + assert len(logs) == n_time + expected_step_keys = [ + "mean/forecast_step", + "mean/weighted_mean_gen/a", + "mean/weighted_mean_target/a", + "mean/weighted_rmse/a", + "mean/weighted_std_gen/a", + "mean/weighted_bias/a", + "mean/weighted_grad_mag_percent_diff/a", + "mean_norm/forecast_step", + "mean_norm/weighted_mean_gen/a", + "mean_norm/weighted_mean_target/a", + "mean_norm/weighted_rmse/a", + 
"mean_norm/weighted_std_gen/a", + "mean_norm/weighted_bias/a", + ] + for log in logs: + for key in expected_step_keys: + assert key in log, key + assert len(log) == len(expected_step_keys), set(log).difference( + expected_step_keys + ) + + summary_logs = agg.get_summary_logs() + expected_keys = [ + "mean_step_20/loss", + "mean_step_20/weighted_rmse/a", + "mean_step_20/weighted_bias/a", + "mean_step_20/weighted_grad_mag_percent_diff/a", + "spherical_power_spectrum/a", + "time_mean/rmse/a", + "time_mean/bias/a", + "time_mean/bias_map/a", + "time_mean/gen_map/a", + "time_mean_norm/rmse/a", + "time_mean_norm/gen_map/a", + "time_mean_norm/rmse/channel_mean", + "zonal_mean/error/a", + "zonal_mean/gen/a", + "video/a", + ] + for key in expected_keys: + assert key in summary_logs, key + assert len(summary_logs) == len(expected_keys), set(summary_logs).difference( + expected_keys + ) + + +def test_inference_logs_labels_exist(): + n_sample = 10 + n_time = 22 + nx = 2 + ny = 2 + nz = 3 + vertical_coordinate = HybridSigmaPressureCoordinate( + torch.arange(nz + 1), torch.arange(nz + 1) + ) + horizontal_coordinates = LatLonCoordinates( + lon=torch.arange(nx), + lat=torch.arange(ny), + loaded_lon_name="lon", + loaded_lat_name="lat", + ) + initial_time = (get_zero_time(shape=[n_sample, 0], dims=["sample", "time"]),) + agg = InferenceEvaluatorAggregator( + vertical_coordinate=vertical_coordinate, + horizontal_coordinates=horizontal_coordinates, + timestep=TIMESTEP, + n_timesteps=n_time, + initial_time=initial_time, + record_step_20=True, + log_video=True, + normalize=lambda x: dict(x), + ) + logs = agg.record_batch( + data=PairedData( + prediction={ + "a": torch.randn(n_sample, n_time, nx, ny, device=get_device()) + }, + target={"a": torch.randn(n_sample, n_time, nx, ny, device=get_device())}, + time=xr.DataArray(np.zeros((n_sample, n_time)), dims=["sample", "time"]), + ), + ) + assert isinstance(logs, list) + assert len(logs) == n_time + assert "mean/weighted_bias/a" in logs[0] + assert "mean/weighted_mean_gen/a" in logs[0] + assert "mean/weighted_mean_target/a" in logs[0] + assert "mean/weighted_grad_mag_percent_diff/a" in logs[0] + assert "mean/weighted_rmse/a" in logs[0] + assert "mean_norm/weighted_bias/a" in logs[0] + assert "mean_norm/weighted_mean_gen/a" in logs[0] + assert "mean_norm/weighted_mean_target/a" in logs[0] + assert "mean_norm/weighted_rmse/a" in logs[0] + # series/table data should be rolled out, not included as a table + assert "mean/series" not in logs[0] + assert "mean_norm/series" not in logs[0] + assert "reduced/series" not in logs[0] + assert "reduced_norm/series" not in logs[0] + + +@pytest.mark.parametrize( + "window_len, n_windows", + [ + pytest.param(3, 1, id="single_window"), + pytest.param(3, 2, id="two_windows"), + ], +) +def test_inference_logs_length(window_len: int, n_windows: int): + """ + Test that the inference logs are the correct length when using one or more + windows. 
+ """ + nz = 3 + nx, ny = 4, 4 + vertical_coordinate = HybridSigmaPressureCoordinate( + torch.arange(nz + 1), torch.arange(nz + 1) + ) + horizontal_coordinates = LatLonCoordinates( + lon=torch.arange(nx), + lat=torch.arange(ny), + loaded_lon_name="lon", + loaded_lat_name="lat", + ) + initial_time = (get_zero_time(shape=[2, 0], dims=["sample", "time"]),) + agg = InferenceEvaluatorAggregator( + vertical_coordinate=vertical_coordinate, + horizontal_coordinates=horizontal_coordinates, + timestep=TIMESTEP, + n_timesteps=window_len * n_windows, + initial_time=initial_time, + normalize=lambda x: dict(x), + ) + target_data = BatchData.new_on_device( + data={"a": torch.zeros([2, window_len, ny, nx], device=get_device())}, + time=xr.DataArray(np.zeros((2, window_len)), dims=["sample", "time"]), + ) + i_start = 0 + for i in range(n_windows): + sample_data = {"a": torch.zeros([2, window_len, ny, nx], device=get_device())} + for i in range(window_len): + sample_data["a"][..., i, :, :] = float(i_start + i) + paired_data = PairedData.new_on_device( + prediction=sample_data, + target=target_data.data, + time=xr.DataArray(np.zeros((2, window_len)), dims=["sample", "time"]), + ) + logs = agg.record_batch( + data=paired_data, + ) + assert len(logs) == window_len + i_start += window_len diff --git a/fme/fme/ace/aggregator/inference/test_inference.py b/fme/fme/ace/aggregator/inference/test_inference.py new file mode 100644 index 0000000..3faf8dd --- /dev/null +++ b/fme/fme/ace/aggregator/inference/test_inference.py @@ -0,0 +1,106 @@ +import datetime + +import numpy as np +import torch +import xarray as xr + +import fme +from fme.ace.aggregator.inference import InferenceAggregator +from fme.ace.data_loading.batch_data import BatchData +from fme.core.device import get_device +from fme.core.gridded_ops import LatLonOperations + +TIMESTEP = datetime.timedelta(hours=6) + + +def get_zero_time(shape, dims): + return xr.DataArray(np.zeros(shape, dtype="datetime64[ns]"), dims=dims) + + +def test_logs_labels_exist(): + n_sample = 10 + n_time = 22 + nx = 2 + ny = 2 + area_weights = torch.ones(ny).to(fme.get_device()) + agg = InferenceAggregator( + LatLonOperations(area_weights), + n_timesteps=n_time, + ) + gen_data = {"a": torch.randn(n_sample, n_time, nx, ny, device=get_device())} + time = get_zero_time(shape=[n_sample, n_time], dims=["sample", "time"]) + logs = agg.record_batch( + BatchData(data=gen_data, time=time), + ) + assert len(logs) == n_time + expected_step_keys = [ + "mean/forecast_step", + "mean/weighted_mean_gen/a", + "mean/weighted_std_gen/a", + ] + for log in logs: + for key in expected_step_keys: + assert key in log, key + assert len(log) == len(expected_step_keys), set(log).difference( + expected_step_keys + ) + summary_logs = agg.get_summary_logs() + expected_summary_keys = ["time_mean/gen_map/a"] + for key in expected_summary_keys: + assert key in summary_logs, key + assert len(summary_logs) == len(expected_summary_keys), set( + summary_logs + ).difference(expected_summary_keys) + + +def test_logs_labels_exist_with_reference_time_means(): + n_sample = 10 + n_time = 22 + nx = 2 + ny = 2 + area_weights = torch.ones(ny).to(fme.get_device()) + reference_time_means = xr.Dataset( + { + "a": xr.DataArray( + np.random.randn(ny, nx), + dims=["grid_yt", "grid_xt"], + ) + } + ) + agg = InferenceAggregator( + LatLonOperations(area_weights), + n_timesteps=n_time, + time_mean_reference_data=reference_time_means, + ) + gen_data = {"a": torch.randn(n_sample, n_time, nx, ny, device=get_device())} + time = 
get_zero_time(shape=[n_sample, n_time], dims=["sample", "time"]) + logs = agg.record_batch( + BatchData( + data=gen_data, + time=time, + ), + ) + assert len(logs) == n_time + expected_step_keys = [ + "mean/forecast_step", + "mean/weighted_mean_gen/a", + "mean/weighted_std_gen/a", + ] + for log in logs: + for key in expected_step_keys: + assert key in log, key + assert len(log) == len(expected_step_keys), set(log).difference( + expected_step_keys + ) + summary_logs = agg.get_summary_logs() + expected_summary_keys = [ + "time_mean/gen_map/a", + "time_mean/ref_bias_map/a", + "time_mean/ref_bias/a", + "time_mean/ref_rmse/a", + ] + for key in expected_summary_keys: + assert key in summary_logs, key + assert len(summary_logs) == len(expected_summary_keys), set( + summary_logs + ).difference(expected_summary_keys) diff --git a/fme/fme/core/aggregator/inference/test_reduced.py b/fme/fme/ace/aggregator/inference/test_reduced.py similarity index 56% rename from fme/fme/core/aggregator/inference/test_reduced.py rename to fme/fme/ace/aggregator/inference/test_reduced.py index b67fc6d..83f5e45 100644 --- a/fme/fme/core/aggregator/inference/test_reduced.py +++ b/fme/fme/ace/aggregator/inference/test_reduced.py @@ -2,11 +2,12 @@ import torch import fme -from fme.core.aggregator.inference.reduced import ( +from fme.ace.aggregator.inference.reduced import ( AreaWeightedReducedMetric, SingleTargetMeanAggregator, ) from fme.core.device import get_device +from fme.core.gridded_ops import LatLonOperations def test_area_weighted_reduced_metric_includes_later_window_starts(): @@ -19,7 +20,6 @@ def compute_metric(truth, predicted, weights=None, dim=()): return truth.mean(dim=(2, 3)) metric = AreaWeightedReducedMetric( - area_weights=torch.ones([4]), device=get_device(), compute_metric=compute_metric, n_timesteps=7, @@ -49,14 +49,33 @@ def test_single_target_mean_aggregator(): nx = 2 ny = 2 area_weights = torch.ones(ny).to(fme.get_device()) + torch.manual_seed(0) agg = SingleTargetMeanAggregator( - area_weights, - n_time_per_window * n_window, + gridded_operations=LatLonOperations(area_weights), + n_timesteps=n_time_per_window * n_window, ) - data = {"a": torch.randn(n_sample, n_time_per_window, nx, ny, device=get_device())} + data_a = torch.randn(n_sample, n_time_per_window, nx, ny, device=get_device()) for i in range(n_window): - agg.record_batch(data, i_time_start=i * n_time_per_window) + data = {"a": data_a[:, i * n_time_per_window : (i + 1) * n_time_per_window]} + agg.record_batch(data=data, i_time_start=i * n_time_per_window) logs = agg.get_logs(label="test") assert "test/series" in logs + ds = agg.get_dataset() + for i in range(1, data_a.shape[1]): + raw_variable = data_a[:, i] + raw_global_mean = raw_variable.mean().cpu().numpy() + raw_global_std = ( + raw_variable.std(dim=(1, 2), correction=0).mean().cpu().numpy() + ) # metrics are mean over batch + np.testing.assert_allclose( + raw_global_std, + ds["weighted_std_gen-a"].isel(forecast_step=i).values.item(), + rtol=1e-5, + ) + np.testing.assert_allclose( + raw_global_mean, + ds["weighted_mean_gen-a"].isel(forecast_step=i).values.item(), + rtol=1e-5, + ) diff --git a/fme/fme/core/aggregator/inference/test_seasonal.py b/fme/fme/ace/aggregator/inference/test_seasonal.py similarity index 91% rename from fme/fme/core/aggregator/inference/test_seasonal.py rename to fme/fme/ace/aggregator/inference/test_seasonal.py index 7c9fe77..ed1107e 100644 --- a/fme/fme/core/aggregator/inference/test_seasonal.py +++ b/fme/fme/ace/aggregator/inference/test_seasonal.py @@ -6,8 
+6,9 @@ import xarray as xr import fme -from fme.core.aggregator.inference.seasonal import SeasonalAggregator +from fme.ace.aggregator.inference.seasonal import SeasonalAggregator from fme.core.device import get_device +from fme.core.gridded_ops import LatLonOperations def get_zero_time(shape, dims): @@ -23,7 +24,7 @@ def test_seasonal_aggregator(): n_time = int(365 / 10 * 2 / n_time_step + 1) * n_time_step area_weights = torch.ones(n_lat, n_lon).to(fme.get_device()) agg = SeasonalAggregator( - area_weights=area_weights, + LatLonOperations(area_weights), ) target_data = { "a": torch.randn(n_sample, n_time, n_lat, n_lon, device=get_device()) diff --git a/fme/fme/core/aggregator/inference/test_spectrum.py b/fme/fme/ace/aggregator/inference/test_spectrum.py similarity index 92% rename from fme/fme/core/aggregator/inference/test_spectrum.py rename to fme/fme/ace/aggregator/inference/test_spectrum.py index 75ffc7a..1346c41 100644 --- a/fme/fme/core/aggregator/inference/test_spectrum.py +++ b/fme/fme/ace/aggregator/inference/test_spectrum.py @@ -3,7 +3,7 @@ import torch_harmonics import fme -from fme.core.aggregator.inference.spectrum import ( +from fme.ace.aggregator.inference.spectrum import ( PairedSphericalPowerSpectrumAggregator, SphericalPowerSpectrumAggregator, ) @@ -34,6 +34,6 @@ def test_paired_spherical_power_spectrum_aggregator(): nlon = 16 agg = PairedSphericalPowerSpectrumAggregator(nlat, nlon) data = {"a": torch.randn(2, 3, nlat, nlon, device=fme.get_device())} - agg.record_batch(0.0, data, data, None, None) + agg.record_batch(data, data, None, None) result = agg.get_logs("spectrum") assert isinstance(result["spectrum/a"], plt.Figure) diff --git a/fme/fme/core/aggregator/inference/test_time_mean.py b/fme/fme/ace/aggregator/inference/test_time_mean.py similarity index 60% rename from fme/fme/core/aggregator/inference/test_time_mean.py rename to fme/fme/ace/aggregator/inference/test_time_mean.py index e91ef59..b2feef6 100644 --- a/fme/fme/core/aggregator/inference/test_time_mean.py +++ b/fme/fme/ace/aggregator/inference/test_time_mean.py @@ -1,17 +1,22 @@ import numpy as np import torch -from fme.core.aggregator.inference.time_mean import ( +from fme.ace.aggregator.inference.time_mean import ( TimeMeanAggregator, TimeMeanEvaluatorAggregator, ) from fme.core.device import get_device +from fme.core.gridded_ops import LatLonOperations def test_rmse_of_time_mean_all_channels(): torch.manual_seed(0) area_weights = torch.ones(1).to(get_device()) - agg = TimeMeanEvaluatorAggregator(area_weights, target="norm") + agg = TimeMeanEvaluatorAggregator( + LatLonOperations(area_weights), + horizontal_dims=["lat", "lon"], + target="norm", + ) target_data_norm = { "a": torch.ones([2, 3, 4, 4], device=get_device()), "b": torch.ones([2, 3, 4, 4], device=get_device()) * 3, @@ -21,7 +26,6 @@ def test_rmse_of_time_mean_all_channels(): "b": torch.ones([2, 3, 4, 4], device=get_device()) * 5, } agg.record_batch( - loss=1.0, target_data=target_data_norm, gen_data=gen_data_norm, target_data_norm=target_data_norm, @@ -33,9 +37,42 @@ def test_rmse_of_time_mean_all_channels(): assert logs["time_mean_norm/rmse/channel_mean"] == 1.5 +def test_custom_channel_mean_names(): + torch.manual_seed(0) + area_weights = torch.ones(1).to(get_device()) + agg = TimeMeanEvaluatorAggregator( + LatLonOperations(area_weights), + horizontal_dims=["lat", "lon"], + target="norm", + channel_mean_names=["a"], + ) + target_data_norm = { + "a": torch.ones([2, 3, 4, 4], device=get_device()), + "b": torch.ones([2, 3, 4, 4], 
device=get_device()) * 3, + } + gen_data_norm = { + "a": torch.ones([2, 3, 4, 4], device=get_device()) * 2.0, + "b": torch.ones([2, 3, 4, 4], device=get_device()) * 5, + } + agg.record_batch( + target_data=target_data_norm, + gen_data=gen_data_norm, + target_data_norm=target_data_norm, + gen_data_norm=gen_data_norm, + ) + logs = agg.get_logs(label="time_mean_norm") + assert logs["time_mean_norm/rmse/a"] == 1.0 + assert logs["time_mean_norm/rmse/b"] == 2.0 + assert logs["time_mean_norm/rmse/channel_mean"] == 1.0 + + def test_mean_all_channels_not_in_denorm(): area_weights = torch.ones(1).to(get_device()) - agg = TimeMeanEvaluatorAggregator(area_weights, target="denorm") + agg = TimeMeanEvaluatorAggregator( + LatLonOperations(area_weights), + horizontal_dims=["lat", "lon"], + target="denorm", + ) target_data = { "a": torch.ones([2, 3, 4, 4], device=get_device()), "b": torch.ones([2, 3, 4, 4], device=get_device()) * 3, @@ -45,7 +82,6 @@ def test_mean_all_channels_not_in_denorm(): "b": torch.ones([2, 3, 4, 4], device=get_device()) * 5, } agg.record_batch( - loss=1.0, target_data=target_data, gen_data=gen_data, target_data_norm=target_data, @@ -60,16 +96,19 @@ def test_mean_all_channels_not_in_denorm(): def test_bias_values(): area_weights = torch.ones(1).to(get_device()) - agg = TimeMeanEvaluatorAggregator(area_weights, target="denorm") + agg = TimeMeanEvaluatorAggregator( + LatLonOperations(area_weights), + horizontal_dims=["lat", "lon"], + target="denorm", + ) # use constant values so area-weighting doesn't matter target_data = { - "a": torch.rand(1) * torch.ones(size=[2, 3, 4, 5], device=get_device()), + "a": (torch.rand(1) * torch.ones(size=[2, 3, 4, 5])).to(device=get_device()), } gen_data = { - "a": torch.rand(1) * torch.ones(size=[2, 3, 4, 5], device=get_device()), + "a": (torch.rand(1) * torch.ones(size=[2, 3, 4, 5])).to(device=get_device()), } agg.record_batch( - loss=1.0, target_data=target_data, gen_data=gen_data, target_data_norm=target_data, @@ -93,10 +132,10 @@ def test_bias_values(): def test_aggregator_mean_values(): area_weights = torch.ones(1).to(get_device()) - agg = TimeMeanAggregator(area_weights) + agg = TimeMeanAggregator(LatLonOperations(area_weights)) # use constant values so area-weighting doesn't matter data = { - "a": torch.rand(1) * torch.ones(size=[2, 3, 4, 5], device=get_device()), + "a": (torch.rand(1) * torch.ones(size=[2, 3, 4, 5])).to(device=get_device()), } agg.record_batch( data=data, diff --git a/fme/fme/core/aggregator/inference/test_video.py b/fme/fme/ace/aggregator/inference/test_video.py similarity index 96% rename from fme/fme/core/aggregator/inference/test_video.py rename to fme/fme/ace/aggregator/inference/test_video.py index f82eafe..73a8d80 100644 --- a/fme/fme/core/aggregator/inference/test_video.py +++ b/fme/fme/ace/aggregator/inference/test_video.py @@ -4,7 +4,7 @@ import pytest import torch -from fme.core.aggregator.inference.video import VideoAggregator +from fme.ace.aggregator.inference.video import VideoAggregator from fme.core.device import get_device from fme.core.typing_ import TensorDict @@ -82,7 +82,7 @@ def test_video_data(offsets: np.ndarray): i_end = i_start + n_window_in_memory target_window, gen_window = time_select(i_start, i_end, gen, target) aggregator.record_batch( - loss=0, target_data=target_window, gen_data=gen_window, i_time_start=i_start + target_data=target_window, gen_data=gen_window, i_time_start=i_start ) data = aggregator._get_data() assert data["bias/a"].target is None @@ -122,7 +122,7 @@ def 
test_video_data_without_extended_videos(offsets: np.ndarray): i_end = i_start + n_window_in_memory target_window, gen_window = time_select(i_start, i_end, gen, target) aggregator.record_batch( - loss=0, target_data=target_window, gen_data=gen_window, i_time_start=i_start + target_data=target_window, gen_data=gen_window, i_time_start=i_start ) data = aggregator._get_data() assert len(data) == 1 @@ -160,7 +160,6 @@ def test_video_data_values_on_random_inputs(n_batches: int): target_window, gen_window = time_select(i_start, i_end, gen, target) for nb in range(n_batches): # shouldn't affect results to duplicate batches aggregator.record_batch( - loss=0, target_data=slice_samples( target_window, i_start=nb * samples_per_batch, diff --git a/fme/fme/core/aggregator/inference/test_zonal_mean.py b/fme/fme/ace/aggregator/inference/test_zonal_mean.py similarity index 79% rename from fme/fme/core/aggregator/inference/test_zonal_mean.py rename to fme/fme/ace/aggregator/inference/test_zonal_mean.py index 713bf84..946c689 100644 --- a/fme/fme/core/aggregator/inference/test_zonal_mean.py +++ b/fme/fme/ace/aggregator/inference/test_zonal_mean.py @@ -1,18 +1,18 @@ import torch +from fme.ace.aggregator.inference.zonal_mean import ZonalMeanAggregator from fme.core import get_device -from fme.core.aggregator.inference.zonal_mean import ZonalMeanAggregator n_sample, n_time, ny, nx = 3, 6, 10, 20 -loss = 1.0 def test_zonal_mean_dims(): agg = ZonalMeanAggregator(n_timesteps=n_time) target_data = {"a": torch.randn(n_sample, n_time, ny, nx, device=get_device())} gen_data = {"a": torch.randn(n_sample, n_time, ny, nx, device=get_device())} - agg.record_batch(loss, target_data, gen_data, target_data, gen_data, i_time_start=0) + agg.record_batch(target_data, gen_data, target_data, gen_data, i_time_start=0) for data in (agg._target_data, agg._gen_data): + assert data is not None assert data["a"].size() == ( n_sample, n_time, @@ -24,10 +24,9 @@ def test_zonal_mean_lat_varying(): agg = ZonalMeanAggregator(n_timesteps=n_time) arr = torch.arange(ny, dtype=torch.float32, device=get_device()) arr = arr[None, None, :, None].expand(n_sample, n_time, -1, nx) - agg.record_batch( - loss, {"a": arr}, {"a": arr}, {"a": arr}, {"a": arr}, i_time_start=0 - ) + agg.record_batch({"a": arr}, {"a": arr}, {"a": arr}, {"a": arr}, i_time_start=0) for data in (agg._target_data, agg._gen_data): + assert data is not None torch.testing.assert_close( data["a"][0, 0, :], # one time row of the zonal mean torch.arange(ny, dtype=torch.float32, device=get_device()), @@ -38,10 +37,9 @@ def test_zonal_mean_zonally_varying(): agg = ZonalMeanAggregator(n_timesteps=n_time) arr = torch.arange(nx, dtype=torch.float32, device=get_device()) arr = arr[None, None, None, :].expand(n_sample, n_time, ny, -1) - agg.record_batch( - loss, {"a": arr}, {"a": arr}, {"a": arr}, {"a": arr}, i_time_start=0 - ) + agg.record_batch({"a": arr}, {"a": arr}, {"a": arr}, {"a": arr}, i_time_start=0) for data in (agg._target_data, agg._gen_data): + assert data is not None torch.testing.assert_close( data["a"][0, 0, :], # one time row of the zonal mean arr.mean() * torch.ones(ny, dtype=torch.float32, device=get_device()), @@ -53,13 +51,12 @@ def test_zonal_mean_batch_varying(): for i in range(n_sample): # assume one sample per batch arr = torch.tensor(i, dtype=torch.float32, device=get_device()) arr = arr[None, None, None, None].expand(-1, n_time, ny, nx) - agg.record_batch( - loss, {"a": arr}, {"a": arr}, {"a": arr}, {"a": arr}, i_time_start=0 - ) + agg.record_batch({"a": arr}, {"a": 
arr}, {"a": arr}, {"a": arr}, i_time_start=0) for data in (agg._target_data, agg._gen_data): + assert data is not None torch.testing.assert_close( data["a"].sum(dim=0)[0, 0], # sum over batches, then pick a time/lat point - torch.arange(n_sample, dtype=torch.float32, device=get_device()).sum() + torch.arange(n_sample, dtype=torch.float32, device=get_device()).sum(), # should be same as sum over batches ) @@ -72,9 +69,10 @@ def test_zonal_mean_mulitple_time_slices(): arr = torch.arange(ny, dtype=torch.float32, device=get_device()) arr = arr[None, None, :, None].expand(n_sample, n_time_in_memory, ny, nx) agg.record_batch( - loss, {"a": arr}, {"a": arr}, {"a": arr}, {"a": arr}, i_time_start=i_time + {"a": arr}, {"a": arr}, {"a": arr}, {"a": arr}, i_time_start=i_time ) for data in (agg._target_data, agg._gen_data): + assert data is not None torch.testing.assert_close( (data["a"] / agg._n_batches)[0, 0, :], torch.arange(ny, dtype=torch.float32, device=get_device()), diff --git a/fme/fme/core/aggregator/inference/time_mean.py b/fme/fme/ace/aggregator/inference/time_mean.py similarity index 74% rename from fme/fme/core/aggregator/inference/time_mean.py rename to fme/fme/ace/aggregator/inference/time_mean.py index ce1335b..01ba79f 100644 --- a/fme/fme/core/aggregator/inference/time_mean.py +++ b/fme/fme/ace/aggregator/inference/time_mean.py @@ -5,9 +5,9 @@ import torch import xarray as xr -from fme.core import metrics -from fme.core.data_loading.data_typing import VariableMetadata +from fme.core.dataset.data_typing import VariableMetadata from fme.core.distributed import Distributed +from fme.core.gridded_ops import GriddedOperations from fme.core.typing_ import TensorDict, TensorMapping from fme.core.wandb import Image, WandB @@ -19,26 +19,27 @@ class _TargetGenPair: name: str target: torch.Tensor gen: torch.Tensor + ops: GriddedOperations def bias(self): return self.gen - self.target - def rmse(self, weights: torch.Tensor) -> float: + def rmse(self) -> float: ret = float( - metrics.root_mean_squared_error( + self.ops.area_weighted_rmse( predicted=self.gen, truth=self.target, - weights=weights, ) .cpu() .numpy() ) return ret - def weighted_mean_bias(self, weights: torch.Tensor) -> float: + def weighted_mean_bias(self) -> float: return float( - metrics.weighted_mean_bias( - predicted=self.gen, truth=self.target, weights=weights + self.ops.area_weighted_mean_bias( + predicted=self.gen, + truth=self.target, ) .cpu() .numpy() @@ -58,27 +59,27 @@ class TimeMeanAggregator: def __init__( self, - area_weights: torch.Tensor, + gridded_operations: GriddedOperations, target: Literal["norm", "denorm"] = "denorm", - metadata: Optional[Mapping[str, VariableMetadata]] = None, + variable_metadata: Optional[Mapping[str, VariableMetadata]] = None, reference_means: Optional[xr.Dataset] = None, ): """ Args: - area_weights: Area weights for each grid cell. + gridded_operations: Computes gridded operations. target: Whether to compute metrics on the normalized or denormalized data, defaults to "denorm". - metadata: Mapping of variable names their metadata that will + variable_metadata: Mapping of variable names their metadata that will used in generating logged image captions. reference_means: Dataset containing reference time-mean values for bias computation. 
""" - self._area_weights = area_weights + self._ops = gridded_operations self._target = target - if metadata is None: - self._metadata: Mapping[str, VariableMetadata] = {} + if variable_metadata is None: + self._variable_metadata: Mapping[str, VariableMetadata] = {} else: - self._metadata = metadata + self._variable_metadata = variable_metadata # Dictionaries of tensors of shape [n_lat, n_lon] represnting time means self._data: Optional[TensorDict] = None self._n_timesteps = 0 @@ -166,6 +167,7 @@ def get_logs(self, label: str) -> Dict[str, Union[float, Image]]: self._reference_means[name].values, device=pred.device ), gen=pred, + ops=self._ops, ) bias_map = pair.bias().cpu().numpy() vmin_bias, vmax_bias = get_cmap_limits(bias_map, diverging=True) @@ -176,10 +178,8 @@ def get_logs(self, label: str) -> Dict[str, Union[float, Image]]: bias_fig, caption=self._get_caption("bias_map", name, vmin_bias, vmax_bias), ) - logs.update( - {f"ref_bias/{name}": pair.weighted_mean_bias(self._area_weights)} - ) - logs.update({f"ref_rmse/{name}": pair.rmse(self._area_weights)}) + logs.update({f"ref_bias/{name}": pair.weighted_mean_bias()}) + logs.update({f"ref_rmse/{name}": pair.rmse()}) logs.update({f"ref_bias_map/{name}": bias_image}) if len(label) != 0: @@ -187,9 +187,9 @@ def get_logs(self, label: str) -> Dict[str, Union[float, Image]]: return logs def _get_caption(self, key: str, name: str, vmin: float, vmax: float) -> str: - if name in self._metadata: - caption_name = self._metadata[name].long_name - units = self._metadata[name].units + if name in self._variable_metadata: + caption_name = self._variable_metadata[name].long_name + units = self._variable_metadata[name].units else: caption_name, units = name, "unknown_units" caption = self._image_captions[key].format(name=caption_name, units=units) @@ -200,9 +200,9 @@ def get_dataset(self) -> xr.Dataset: dims = ("lat", "lon") data = {} for name, pred in self.get_data().items(): - if name in self._metadata: - long_name = self._metadata[name].long_name - units = self._metadata[name].units + if name in self._variable_metadata: + long_name = self._variable_metadata[name].long_name + units = self._variable_metadata[name].units else: long_name = name units = "unknown_units" @@ -233,43 +233,49 @@ class TimeMeanEvaluatorAggregator: def __init__( self, - area_weights: torch.Tensor, + ops: GriddedOperations, + horizontal_dims: List[str], target: Literal["norm", "denorm"] = "denorm", - metadata: Optional[Mapping[str, VariableMetadata]] = None, + variable_metadata: Optional[Mapping[str, VariableMetadata]] = None, reference_means: Optional[xr.Dataset] = None, + channel_mean_names: Optional[List[str]] = None, ): """ Args: - area_weights: Area weights for each grid cell. + ops: Computes gridded operations. + horizontal_dims: Names of the horizontal dimensions. target: Whether to compute metrics on the normalized or denormalized data, defaults to "denorm". - metadata: Mapping of variable names their metadata that will + variable_metadata: Mapping of variable names their metadata that will used in generating logged image captions. reference_means: Dataset containing reference time-mean values for bias computation. + channel_mean_names: Names of variables whose RMSE will be averaged. If + not provided, all available variables will be used. 
""" - self._area_weights = area_weights + self._ops = ops + self._horizontal_dims = horizontal_dims self._target = target self._dist = Distributed.get_instance() - if metadata is None: - self._metadata: Mapping[str, VariableMetadata] = {} + if variable_metadata is None: + self._variable_metadata: Mapping[str, VariableMetadata] = {} else: - self._metadata = metadata + self._variable_metadata = variable_metadata # Dictionaries of tensors of shape [n_lat, n_lon] represnting time means self._target_agg = TimeMeanAggregator( - area_weights=area_weights, target=target, metadata=metadata + gridded_operations=ops, target=target, variable_metadata=variable_metadata ) self._gen_agg = TimeMeanAggregator( - area_weights=area_weights, + gridded_operations=ops, target=target, - metadata=metadata, + variable_metadata=variable_metadata, reference_means=reference_means, ) + self._channel_mean_names = channel_mean_names @torch.no_grad() def record_batch( self, - loss: float, target_data: TensorMapping, gen_data: TensorMapping, target_data_norm: TensorMapping, @@ -289,7 +295,12 @@ def _get_target_gen_pairs(self) -> List[_TargetGenPair]: ret = [] for name in gen_data.keys(): ret.append( - _TargetGenPair(gen=gen_data[name], target=target_data[name], name=name) + _TargetGenPair( + gen=gen_data[name], + target=target_data[name], + name=name, + ops=self._ops, + ) ) return ret @@ -313,33 +324,33 @@ def get_logs(self, label: str) -> Dict[str, Union[float, torch.Tensor, Image]]: ), ) plt.close("all") - rmse_all_channels[pred.name] = pred.rmse(weights=self._area_weights) + rmse_all_channels[pred.name] = pred.rmse() logs.update({f"rmse/{pred.name}": rmse_all_channels[pred.name]}) if self._target == "denorm": logs.update( { f"{bias_map_key}/{pred.name}": bias_image, - f"bias/{pred.name}": pred.weighted_mean_bias( - weights=self._area_weights - ), + f"bias/{pred.name}": pred.weighted_mean_bias(), } ) if self._target == "norm": - logs.update( - { - f"rmse/channel_mean": sum(rmse_all_channels.values()) - / len(rmse_all_channels), - } - ) + metric_name = "rmse/channel_mean" + if self._channel_mean_names is None: + values_to_average = list(rmse_all_channels.values()) + else: + values_to_average = [ + rmse_all_channels[name] for name in self._channel_mean_names + ] + logs.update({metric_name: sum(values_to_average) / len(values_to_average)}) if len(label) != 0: return {f"{label}/{key}": logs[key] for key in logs} return logs def _get_caption(self, key: str, name: str, vmin: float, vmax: float) -> str: - if name in self._metadata: - caption_name = self._metadata[name].long_name - units = self._metadata[name].units + if name in self._variable_metadata: + caption_name = self._variable_metadata[name].long_name + units = self._variable_metadata[name].units else: caption_name, units = name, "unknown_units" caption = self._image_captions[key].format(name=caption_name, units=units) @@ -349,27 +360,27 @@ def _get_caption(self, key: str, name: str, vmin: float, vmax: float) -> str: def get_dataset(self) -> xr.Dataset: data = {} preds = self._get_target_gen_pairs() - dims = ("lat", "lon") for pred in preds: - if pred.name in self._metadata: - long_name = self._metadata[pred.name].long_name - units = self._metadata[pred.name].units + if pred.name in self._variable_metadata: + long_name = self._variable_metadata[pred.name].long_name + units = self._variable_metadata[pred.name].units else: long_name = pred.name units = "unknown_units" gen_metadata = VariableMetadata(long_name=long_name, units=units)._asdict() - bias_metadata = 
self._metadata.get( + bias_metadata = self._variable_metadata.get( pred.name, VariableMetadata(long_name=long_name, units=units) )._asdict() - gen_metadata = VariableMetadata(long_name=long_name, units=units)._asdict() data.update( { f"bias_map-{pred.name}": xr.DataArray( - pred.bias().cpu(), dims=dims, attrs=bias_metadata + pred.bias().cpu(), + dims=self._horizontal_dims, + attrs=bias_metadata, ), f"gen_map-{pred.name}": xr.DataArray( pred.gen.cpu(), - dims=dims, + dims=self._horizontal_dims, attrs=gen_metadata, ), } diff --git a/fme/fme/core/aggregator/inference/video.py b/fme/fme/ace/aggregator/inference/video.py similarity index 95% rename from fme/fme/core/aggregator/inference/video.py rename to fme/fme/ace/aggregator/inference/video.py index 0bb17e8..24090a8 100644 --- a/fme/fme/core/aggregator/inference/video.py +++ b/fme/fme/ace/aggregator/inference/video.py @@ -5,7 +5,7 @@ import torch import xarray as xr -from fme.core.data_loading.data_typing import VariableMetadata +from fme.core.dataset.data_typing import VariableMetadata from fme.core.distributed import Distributed from fme.core.typing_ import TensorDict, TensorMapping from fme.core.wandb import WandB @@ -216,9 +216,7 @@ def record_batch( time_slice = slice(i_time_start, i_time_start + window_steps) for name, tensor in target_data.items(): self._target_means[name][time_slice, ...] += tensor.mean(dim=0).cpu() - self._target_squares[name][time_slice, ...] += ( - (tensor**2).mean(dim=0).cpu() - ) + self._target_squares[name][time_slice, ...] += (tensor**2).mean(dim=0).cpu() for name, tensor in gen_data.items(): self._gen_means[name][time_slice, ...] += tensor.mean(dim=0).cpu() self._gen_squares[name][time_slice, ...] += (tensor**2).mean(dim=0).cpu() @@ -292,20 +290,20 @@ def __init__( self, n_timesteps: int, enable_extended_videos: bool, - metadata: Optional[Mapping[str, VariableMetadata]] = None, + variable_metadata: Optional[Mapping[str, VariableMetadata]] = None, ): """ Args: n_timesteps: Number of timesteps of inference that will be run. enable_extended_videos: Whether to log videos of statistical metrics of state evolution - metadata: Mapping of variable names their metadata that will + variable_metadata: Mapping of variable names their metadata that will used in generating logged video captions. 
""" - if metadata is None: - self._metadata: Mapping[str, VariableMetadata] = {} + if variable_metadata is None: + self._variable_metadata: Mapping[str, VariableMetadata] = {} else: - self._metadata = metadata + self._variable_metadata = variable_metadata self._mean_data = _MeanVideoData(n_timesteps=n_timesteps) if enable_extended_videos: self._error_data: Optional[_ErrorVideoData] = _ErrorVideoData( @@ -323,7 +321,6 @@ def __init__( @torch.no_grad() def record_batch( self, - loss: float, target_data: TensorMapping, gen_data: TensorMapping, target_data_norm: Optional[TensorMapping] = None, @@ -375,14 +372,14 @@ def _get_data(self) -> Mapping[str, _MaybePairedVideoData]: video_data = {} def get_units(name: str) -> Optional[str]: - if name in self._metadata: - return self._metadata[name].units + if name in self._variable_metadata: + return self._variable_metadata[name].units else: return None def get_long_name(name: str) -> Optional[str]: - if name in self._metadata: - return self._metadata[name].long_name + if name in self._variable_metadata: + return self._variable_metadata[name].long_name else: return None @@ -480,9 +477,9 @@ def _get_caption(self, name: str) -> str: caption = ( "Autoregressive (left) prediction and (right) target for {name} [{units}]" ) - if name in self._metadata: - caption_name = self._metadata[name].long_name - units = self._metadata[name].units + if name in self._variable_metadata: + caption_name = self._variable_metadata[name].long_name + units = self._variable_metadata[name].units else: caption_name, units = name, "unknown units" return caption.format(name=caption_name, units=units) diff --git a/fme/fme/core/aggregator/inference/zonal_mean.py b/fme/fme/ace/aggregator/inference/zonal_mean.py similarity index 86% rename from fme/fme/core/aggregator/inference/zonal_mean.py rename to fme/fme/ace/aggregator/inference/zonal_mean.py index 87ae6a8..5c6de1a 100644 --- a/fme/fme/core/aggregator/inference/zonal_mean.py +++ b/fme/fme/ace/aggregator/inference/zonal_mean.py @@ -6,7 +6,7 @@ import torch import xarray as xr -from fme.core.data_loading.data_typing import VariableMetadata +from fme.core.dataset.data_typing import VariableMetadata from fme.core.device import get_device from fme.core.distributed import Distributed from fme.core.typing_ import TensorDict, TensorMapping @@ -61,32 +61,29 @@ class ZonalMeanAggregator: def __init__( self, n_timesteps: int, - metadata: Optional[Mapping[str, VariableMetadata]] = None, + variable_metadata: Optional[Mapping[str, VariableMetadata]] = None, ): """ Args: n_timesteps: Number of timesteps of inference that will be run. - metadata: Mapping of variable names their metadata that will + variable_metadata: Mapping of variable names their metadata that will used in generating logged image captions. 
""" self._n_timesteps = n_timesteps self._dist = Distributed.get_instance() - if metadata is None: - self._metadata: Mapping[str, VariableMetadata] = {} + if variable_metadata is None: + self._variable_metadata: Mapping[str, VariableMetadata] = {} else: - self._metadata = metadata + self._variable_metadata = variable_metadata self._target_data: Optional[TensorDict] = None self._gen_data: Optional[TensorDict] = None self._n_batches = torch.zeros( n_timesteps, dtype=torch.int32, device=get_device() - )[ - None, :, None - ] # sample, time, lat + )[None, :, None] # sample, time, lat def record_batch( self, - loss: float, target_data: TensorMapping, gen_data: TensorMapping, target_data_norm: TensorMapping, @@ -107,9 +104,11 @@ def record_batch( time_slice = slice(i_time_start, i_time_start + window_steps) # we can average along longitude without area weighting for name, tensor in target_data.items(): - self._target_data[name][:, time_slice, :] += tensor.mean(dim=lon_dim) + if name in self._target_data: + self._target_data[name][:, time_slice, :] += tensor.mean(dim=lon_dim) for name, tensor in gen_data.items(): - self._gen_data[name][:, time_slice, :] += tensor.mean(dim=lon_dim) + if name in self._gen_data: + self._gen_data[name][:, time_slice, :] += tensor.mean(dim=lon_dim) self._n_batches[:, time_slice, :] += 1 def _get_data(self) -> Dict[str, _RawData]: @@ -134,7 +133,9 @@ def _get_data(self) -> Dict[str, _RawData]: .numpy() ) - metadata = self._metadata.get(name, VariableMetadata("unknown_units", name)) + metadata = self._variable_metadata.get( + name, VariableMetadata("unknown_units", name) + ) vmin, vmax = get_cmap_limits(gen) data[f"gen/{name}"] = _RawData( datum=gen, @@ -173,9 +174,9 @@ def get_dataset(self) -> xr.Dataset: return ret def _get_caption(self, key: str, varname: str, vmin: float, vmax: float) -> str: - if varname in self._metadata: - caption_name = self._metadata[varname].long_name - units = self._metadata[varname].units + if varname in self._variable_metadata: + caption_name = self._variable_metadata[varname].long_name + units = self._variable_metadata[varname].units else: caption_name, units = varname, "unknown_units" caption = self._captions[key].format(name=caption_name, units=units) diff --git a/fme/fme/core/aggregator/null.py b/fme/fme/ace/aggregator/null.py similarity index 100% rename from fme/fme/core/aggregator/null.py rename to fme/fme/ace/aggregator/null.py diff --git a/fme/fme/core/aggregator/one_step/__init__.py b/fme/fme/ace/aggregator/one_step/__init__.py similarity index 100% rename from fme/fme/core/aggregator/one_step/__init__.py rename to fme/fme/ace/aggregator/one_step/__init__.py diff --git a/fme/fme/core/aggregator/one_step/main.py b/fme/fme/ace/aggregator/one_step/main.py similarity index 66% rename from fme/fme/core/aggregator/one_step/main.py rename to fme/fme/ace/aggregator/one_step/main.py index 23d6d0a..18a5121 100644 --- a/fme/fme/core/aggregator/one_step/main.py +++ b/fme/fme/ace/aggregator/one_step/main.py @@ -1,9 +1,12 @@ -from typing import Mapping, Optional, Protocol +from typing import Dict, Mapping, Optional, Protocol +import numpy as np import torch -from fme.core.aggregator.one_step.derived import DerivedMetricsAggregator -from fme.core.data_loading.data_typing import SigmaCoordinates, VariableMetadata +from fme.ace.stepper import TrainOutput +from fme.core.dataset.data_typing import VariableMetadata +from fme.core.generics.aggregator import AggregatorABC +from fme.core.gridded_ops import GriddedOperations from fme.core.typing_ import 
TensorMapping from .map import MapAggregator @@ -12,8 +15,7 @@ class _Aggregator(Protocol): - def get_logs(self, label: str) -> TensorMapping: - ... + def get_logs(self, label: str) -> TensorMapping: ... def record_batch( self, @@ -22,11 +24,10 @@ def record_batch( gen_data: TensorMapping, target_data_norm: TensorMapping, gen_data_norm: TensorMapping, - ) -> None: - ... + ) -> None: ... -class OneStepAggregator: +class OneStepAggregator(AggregatorABC[TrainOutput]): """ Aggregates statistics for the first timestep. @@ -36,46 +37,42 @@ class OneStepAggregator: def __init__( self, - area_weights: torch.Tensor, - sigma_coordinates: SigmaCoordinates, - metadata: Optional[Mapping[str, VariableMetadata]] = None, + gridded_operations: GriddedOperations, + variable_metadata: Optional[Mapping[str, VariableMetadata]] = None, loss_scaling: Optional[TensorMapping] = None, ): """ Args: - area_weights: Weights for each horizontal grid coordinate - sigma_coordinates: Coordinates for defining pressure levels. - metadata: Metadata for each variable. + gridded_operations: Operations for computing metrics on gridded data. + variable_metadata: Metadata for each variable. loss_scaling: Dictionary of variables and their scaling factors used in loss computation. """ - self._aggregators: Mapping[str, _Aggregator] = { - "snapshot": SnapshotAggregator(metadata), - "mean": MeanAggregator(area_weights), - "derived": DerivedMetricsAggregator(area_weights, sigma_coordinates), - "mean_map": MapAggregator(metadata), + aggregators: Dict[str, _Aggregator] = { + "mean": MeanAggregator(gridded_operations) } + aggregators["snapshot"] = SnapshotAggregator(variable_metadata) + aggregators["mean_map"] = MapAggregator(variable_metadata) + self._aggregators = aggregators self._loss_scaling = loss_scaling or {} @torch.no_grad() def record_batch( self, - loss: float, - target_data: TensorMapping, - gen_data: TensorMapping, - target_data_norm: TensorMapping, - gen_data_norm: TensorMapping, + batch: TrainOutput, ): - if len(target_data) == 0: + if len(batch.target_data) == 0: raise ValueError("No data in target_data") - if len(gen_data) == 0: + if len(batch.gen_data) == 0: raise ValueError("No data in gen_data") + gen_data_norm = batch.normalize(batch.gen_data) + target_data_norm = batch.normalize(batch.target_data) for agg in self._aggregators.values(): agg.record_batch( - loss=loss, - target_data=target_data, - gen_data=gen_data, + loss=batch.metrics.get("loss", np.nan), + target_data=batch.target_data, + gen_data=batch.gen_data, target_data_norm=target_data_norm, gen_data_norm=gen_data_norm, ) diff --git a/fme/fme/core/aggregator/one_step/map.py b/fme/fme/ace/aggregator/one_step/map.py similarity index 98% rename from fme/fme/core/aggregator/one_step/map.py rename to fme/fme/ace/aggregator/one_step/map.py index 2adb564..81b4812 100644 --- a/fme/fme/core/aggregator/one_step/map.py +++ b/fme/fme/ace/aggregator/one_step/map.py @@ -2,7 +2,7 @@ import torch -from fme.core.data_loading.data_typing import VariableMetadata +from fme.core.dataset.data_typing import VariableMetadata from fme.core.distributed import Distributed from fme.core.typing_ import TensorDict, TensorMapping from fme.core.wandb import Image diff --git a/fme/fme/core/aggregator/one_step/reduced.py b/fme/fme/ace/aggregator/one_step/reduced.py similarity index 76% rename from fme/fme/core/aggregator/one_step/reduced.py rename to fme/fme/ace/aggregator/one_step/reduced.py index beb9932..6ce35a2 100644 --- a/fme/fme/core/aggregator/one_step/reduced.py +++ 
b/fme/fme/ace/aggregator/one_step/reduced.py @@ -1,11 +1,12 @@ -from typing import Dict, Optional, Union +from typing import Dict, Optional +import numpy as np import torch import xarray as xr -from fme.core import metrics from fme.core.device import get_device from fme.core.distributed import Distributed +from fme.core.gridded_ops import GriddedOperations from fme.core.typing_ import TensorMapping from .reduced_metrics import AreaWeightedReducedMetric, ReducedMetric @@ -28,12 +29,10 @@ class MeanAggregator: def __init__( self, - area_weights: torch.Tensor, + gridded_operations: GriddedOperations, target_time: int = 1, ): - self._area_weights = area_weights - self._shape_x = None - self._shape_y = None + self._gridded_operations = gridded_operations self._n_batches = 0 self._loss = torch.tensor(0.0, device=get_device()) self._variable_metrics: Optional[Dict[str, Dict[str, ReducedMetric]]] = None @@ -49,37 +48,35 @@ def _get_variable_metrics(self, gen_data: TensorMapping): } device = get_device() for key in gen_data: - self._variable_metrics["weighted_rmse"][ - key - ] = AreaWeightedReducedMetric( - area_weights=self._area_weights, - device=device, - compute_metric=metrics.root_mean_squared_error, + self._variable_metrics["weighted_rmse"][key] = ( + AreaWeightedReducedMetric( + device=device, + compute_metric=self._gridded_operations.area_weighted_rmse, + ) ) - self._variable_metrics["weighted_bias"][ - key - ] = AreaWeightedReducedMetric( - area_weights=self._area_weights, - device=device, - compute_metric=metrics.weighted_mean_bias, + self._variable_metrics["weighted_bias"][key] = ( + AreaWeightedReducedMetric( + device=device, + compute_metric=self._gridded_operations.area_weighted_mean_bias, + ) ) - self._variable_metrics["weighted_grad_mag_percent_diff"][ - key - ] = AreaWeightedReducedMetric( - area_weights=self._area_weights, - device=device, - compute_metric=metrics.gradient_magnitude_percent_diff, + self._variable_metrics["weighted_grad_mag_percent_diff"][key] = ( + AreaWeightedReducedMetric( + device=device, + compute_metric=self._gridded_operations.area_weighted_gradient_magnitude_percent_diff, # noqa: E501 + ) ) + return self._variable_metrics @torch.no_grad() def record_batch( self, - loss: float, target_data: TensorMapping, gen_data: TensorMapping, target_data_norm: TensorMapping, gen_data_norm: TensorMapping, + loss: torch.Tensor = torch.tensor(np.nan), i_time_start: int = 0, ): self._loss += loss @@ -102,9 +99,7 @@ def record_batch( def _get_data(self): if self._variable_metrics is None or self._n_batches == 0: raise ValueError("No batches have been recorded.") - data: Dict[str, Union[float, torch.Tensor]] = { - "loss": self._loss / self._n_batches - } + data: Dict[str, torch.Tensor] = {"loss": self._loss / self._n_batches} for metric in self._variable_metrics: for key in self._variable_metrics[metric]: data[f"{metric}/{key}"] = ( diff --git a/fme/fme/core/aggregator/one_step/reduced_metrics.py b/fme/fme/ace/aggregator/one_step/reduced_metrics.py similarity index 82% rename from fme/fme/core/aggregator/one_step/reduced_metrics.py rename to fme/fme/ace/aggregator/one_step/reduced_metrics.py index d821307..3957b3d 100644 --- a/fme/fme/core/aggregator/one_step/reduced_metrics.py +++ b/fme/fme/ace/aggregator/one_step/reduced_metrics.py @@ -4,11 +4,10 @@ to turn metric functions that may have different APIs into a common API, so that they can be iterated over and called in the same way in a loop. 
""" -from typing import Optional, Protocol -import torch +from typing import Protocol -from fme.core.metrics import Dimension +import torch class AreaWeightedFunction(Protocol): @@ -21,10 +20,7 @@ def __call__( self, truth: torch.Tensor, predicted: torch.Tensor, - weights: Optional[torch.Tensor] = None, - dim: Dimension = (), - ) -> torch.Tensor: - ... + ) -> torch.Tensor: ... class ReducedMetric(Protocol): @@ -52,11 +48,9 @@ class AreaWeightedReducedMetric: def __init__( self, - area_weights: torch.Tensor, device: torch.device, compute_metric: AreaWeightedFunction, ): - self._area_weights = area_weights self._compute_metric = compute_metric self._total = None self._device = device @@ -68,9 +62,7 @@ def record(self, target: torch.Tensor, gen: torch.Tensor): target: Target data. Should have shape [batch, time, height, width]. gen: Generated data. Should have shape [batch, time, height, width]. """ - new_value = self._compute_metric( - target, gen, weights=self._area_weights, dim=(-2, -1) - ).mean(dim=0) + new_value = self._compute_metric(target, gen).mean(dim=0) if self._total is None: self._total = torch.zeros_like(new_value, device=self._device) self._total += new_value diff --git a/fme/fme/core/aggregator/one_step/snapshot.py b/fme/fme/ace/aggregator/one_step/snapshot.py similarity index 67% rename from fme/fme/core/aggregator/one_step/snapshot.py rename to fme/fme/ace/aggregator/one_step/snapshot.py index 90f0bae..d066c44 100644 --- a/fme/fme/core/aggregator/one_step/snapshot.py +++ b/fme/fme/ace/aggregator/one_step/snapshot.py @@ -1,13 +1,12 @@ from typing import Dict, Mapping, Optional -import matplotlib.pyplot as plt import torch -from fme.core.data_loading.data_typing import VariableMetadata +from fme.core.dataset.data_typing import VariableMetadata from fme.core.typing_ import TensorMapping -from fme.core.wandb import Image, WandB +from fme.core.wandb import Image -from ..plotting import get_cmap_limits, plot_imshow +from ..plotting import plot_paneled_data class SnapshotAggregator: @@ -18,11 +17,11 @@ class SnapshotAggregator: _captions = { "full-field": ( "{name} one step full field for last sample; " - "(left) generated and (right) target [{units}]" + "(top) generated and (bottom) target [{units}]" ), "residual": ( "{name} one step residual (prediction - previous time) for last sample; " - "(left) generated and (right) target [{units}]" + "(top) generated and (bottom) target [{units}]" ), "error": ( "{name} one step full field error (generated - target) " @@ -68,7 +67,6 @@ def get_logs(self, label: str) -> Dict[str, Image]: input_time = 0 target_time = 1 image_logs = {} - wandb = WandB.get_instance() for name in self._gen_data.keys(): # use first sample in batch gen = self._gen_data[name].select(dim=time_dim, index=target_time)[0].cpu() @@ -78,42 +76,26 @@ def get_logs(self, label: str) -> Dict[str, Image]: input = ( self._target_data[name].select(dim=time_dim, index=input_time)[0].cpu() ) - gap_shape = (input.shape[-2], 4) - gap = torch.full(gap_shape, target.min()) - gap_res = torch.full(gap_shape, (target - input).min()) images = {} - images["error"] = (gen - target).numpy() - images["full-field"] = torch.cat((gen, gap, target), axis=1).numpy() - images["residual"] = torch.cat( - ( - gen - input, - gap_res, - target - input, - ), - axis=1, - ).numpy() + images["error"] = [[(gen - target).numpy()]] + images["full-field"] = [[gen.numpy()], [target.numpy()]] + images["residual"] = [[(gen - input).numpy()], [(target - input).numpy()]] for key, data in images.items(): if key == 
"error" or key == "residual": diverging = True - cmap = "RdBu_r" else: diverging = False - cmap = None - vmin, vmax = get_cmap_limits(data, diverging=diverging) - caption = self._get_caption(key, name, vmin, vmax) - fig = plot_imshow(data, vmin=vmin, vmax=vmax, cmap=cmap) - wandb_image = wandb.Image(fig, caption=caption) - plt.close(fig) + caption = self._get_caption(key, name) + wandb_image = plot_paneled_data(data, diverging, caption=caption) image_logs[f"image-{key}/{name}"] = wandb_image image_logs = {f"{label}/{key}": image_logs[key] for key in image_logs} return image_logs - def _get_caption(self, key: str, name: str, vmin: float, vmax: float) -> str: + def _get_caption(self, key: str, name: str) -> str: if name in self._metadata: caption_name = self._metadata[name].long_name units = self._metadata[name].units else: caption_name, units = name, "unknown_units" caption = self._captions[key].format(name=caption_name, units=units) - caption += f" vmin={vmin:.4g}, vmax={vmax:.4g}." return caption diff --git a/fme/fme/ace/aggregator/one_step/test_main.py b/fme/fme/ace/aggregator/one_step/test_main.py new file mode 100644 index 0000000..84b4ddc --- /dev/null +++ b/fme/fme/ace/aggregator/one_step/test_main.py @@ -0,0 +1,88 @@ +import numpy as np +import pytest +import torch +import xarray as xr + +from fme.ace.aggregator.one_step import OneStepAggregator +from fme.ace.stepper import TrainOutput +from fme.core.device import get_device +from fme.core.gridded_ops import LatLonOperations + + +def test_labels_exist(): + n_sample = 10 + n_time = 3 + nx, ny = 2, 2 + loss = 1.0 + area_weights = torch.ones(ny).to(get_device()) + agg = OneStepAggregator(LatLonOperations(area_weights)) + target_data = {"a": torch.randn(n_sample, n_time, nx, ny, device=get_device())} + gen_data = {"a": torch.randn(n_sample, n_time, nx, ny, device=get_device())} + agg.record_batch( + batch=TrainOutput( + metrics={"loss": loss}, + target_data=target_data, + gen_data=gen_data, + time=xr.DataArray(np.zeros((n_sample, n_time)), dims=["sample", "time"]), + normalize=lambda x: x, + ), + ) + logs = agg.get_logs(label="test") + assert "test/mean/loss" in logs + assert "test/mean/weighted_rmse/a" in logs + assert "test/mean/weighted_bias/a" in logs + assert "test/mean/weighted_grad_mag_percent_diff/a" in logs + assert "test/snapshot/image-full-field/a" in logs + assert "test/snapshot/image-residual/a" in logs + assert "test/snapshot/image-error/a" in logs + + +def test_aggregator_raises_on_no_data(): + """ + Basic test the aggregator combines loss correctly + with multiple batches and no distributed training. 
+ """ + ny = 2 + area_weights = torch.ones(ny).to(get_device()) + agg = OneStepAggregator(LatLonOperations(area_weights)) + with pytest.raises(ValueError) as excinfo: + agg.record_batch( + batch=TrainOutput( + metrics={"loss": 1.0}, + target_data={}, + gen_data={}, + time=xr.DataArray(np.zeros((0, 0)), dims=["sample", "time"]), + normalize=lambda x: x, + ), + ) + # check that the raised exception contains the right substring + assert "No data" in str(excinfo.value) + + +def test__get_loss_scaled_mse_components(): + loss_scaling = { + "a": torch.tensor(1.0), + "b": torch.tensor(0.5), + } + agg = OneStepAggregator( + gridded_operations=LatLonOperations( + area_weights=torch.ones(10).to(get_device()) + ), + loss_scaling=loss_scaling, + ) + + logs = { + "test/mean/weighted_rmse/a": 1.0, + "test/mean/weighted_rmse/b": 4.0, + "test/mean/weighted_rmse/c": 0.0, + } + result = agg._get_loss_scaled_mse_components(logs, "test") + scaled_squared_errors_sum = (1.0 / 1.0) ** 2 + (4.0 / 0.5) ** 2 + assert ( + result["test/mean/mse_fractional_components/a"] == 1 / scaled_squared_errors_sum + ) + assert ( + result["test/mean/mse_fractional_components/b"] + == 64 / scaled_squared_errors_sum + ) + assert "test/mean/mse_fractional_components/c" not in result diff --git a/fme/fme/core/aggregator/one_step/test_reduced.py b/fme/fme/ace/aggregator/one_step/test_reduced.py similarity index 62% rename from fme/fme/core/aggregator/one_step/test_reduced.py rename to fme/fme/ace/aggregator/one_step/test_reduced.py index cbdb10e..25b8960 100644 --- a/fme/fme/core/aggregator/one_step/test_reduced.py +++ b/fme/fme/ace/aggregator/one_step/test_reduced.py @@ -2,8 +2,9 @@ import pytest import torch -from fme.core.aggregator.one_step.reduced import MeanAggregator +from fme.ace.aggregator.one_step.reduced import MeanAggregator from fme.core.device import get_device +from fme.core.gridded_ops import LatLonOperations from fme.core.testing import mock_distributed @@ -15,9 +16,15 @@ def test_mean_metrics_call_distributed(): """ with mock_distributed(-1.0) as mock: area_weights = torch.ones([4]).to(get_device()) - agg = MeanAggregator(area_weights) + agg = MeanAggregator(LatLonOperations(area_weights)) sample_data = {"a": torch.ones([2, 3, 4, 4], device=get_device())} - agg.record_batch(1.0, sample_data, sample_data, sample_data, sample_data) + agg.record_batch( + loss=1.0, + target_data=sample_data, + gen_data=sample_data, + target_data_norm=sample_data, + gen_data_norm=sample_data, + ) logs = agg.get_logs(label="metrics") assert logs["metrics/loss"] == -1.0 assert logs["metrics/weighted_rmse/a"] == -1.0 @@ -30,14 +37,14 @@ def test_i_time_start_gets_correct_time_one_step_windows(): # the data from the correct timestep is piped into the aggregator. target_time = 3 area_weights = torch.ones([4]).to(get_device()) - agg = MeanAggregator(area_weights, target_time=target_time) + agg = MeanAggregator(LatLonOperations(area_weights), target_time=target_time) target_data = {"a": torch.zeros([2, 1, 4, 4], device=get_device())} for i in range(5): sample_data = { "a": torch.full([2, 1, 4, 4], fill_value=float(i), device=get_device()) } agg.record_batch( - 1.0, + loss=1.0, target_data=target_data, gen_data=sample_data, target_data_norm=target_data, @@ -62,7 +69,7 @@ def test_i_time_start_gets_correct_time_longer_windows( # while this directly tests the "mean" result, this is really a test that # the data from the correct timestep is piped into the aggregator. 
area_weights = torch.ones([4]).to(get_device()) - agg = MeanAggregator(area_weights, target_time=target_time) + agg = MeanAggregator(LatLonOperations(area_weights), target_time=target_time) target_data = {"a": torch.zeros([2, window_len, 4, 4], device=get_device())} i_start = 0 for i in range(n_windows): @@ -70,7 +77,7 @@ def test_i_time_start_gets_correct_time_longer_windows( for i in range(window_len): sample_data["a"][..., i, :, :] = float(i_start + i) agg.record_batch( - 1.0, + loss=1.0, target_data=target_data, gen_data=sample_data, target_data_norm=target_data, @@ -82,3 +89,41 @@ def test_i_time_start_gets_correct_time_longer_windows( np.testing.assert_allclose( float(logs["metrics/weighted_bias/a"]), float(target_time), rtol=1e-5 ) + + +def test_loss(): + """ + Basic test the aggregator combines loss correctly + with multiple batches and no distributed training. + """ + torch.manual_seed(0) + example_data = { + "a": torch.randn(1, 2, 5, 5, device=get_device()), + } + area_weights = torch.ones(1).to(get_device()) + aggregator = MeanAggregator(LatLonOperations(area_weights)) + aggregator.record_batch( + loss=1.0, + target_data=example_data, + gen_data=example_data, + target_data_norm=example_data, + gen_data_norm=example_data, + ) + aggregator.record_batch( + loss=2.0, + target_data=example_data, + gen_data=example_data, + target_data_norm=example_data, + gen_data_norm=example_data, + ) + logs = aggregator.get_logs(label="metrics") + assert logs["metrics/loss"] == 1.5 + aggregator.record_batch( + loss=3.0, + target_data=example_data, + gen_data=example_data, + target_data_norm=example_data, + gen_data_norm=example_data, + ) + logs = aggregator.get_logs(label="metrics") + assert logs["metrics/loss"] == 2.0 diff --git a/fme/fme/core/aggregator/plotting.py b/fme/fme/ace/aggregator/plotting.py similarity index 60% rename from fme/fme/core/aggregator/plotting.py rename to fme/fme/ace/aggregator/plotting.py index a4cfc8d..de44c8b 100644 --- a/fme/fme/core/aggregator/plotting.py +++ b/fme/fme/ace/aggregator/plotting.py @@ -6,7 +6,7 @@ from matplotlib.colors import Colormap from matplotlib.figure import Figure -from fme.core.wandb import WandB +from fme.core.wandb import Image, WandB def get_cmap_limits(data: np.ndarray, diverging=False) -> Tuple[float, float]: @@ -27,6 +27,10 @@ def plot_imshow( use_colorbar: bool = True, ) -> Figure: """Plot a 2D array using imshow, ensuring figure size is same as array size.""" + min_ = np.min(data) if vmin is None else vmin + max_ = np.max(data) if vmax is None else vmax + if len(data.shape) == 3: + data = fold_healpix_data(data, fill_value=0.5 * (min_ + max_)) if flip_lat: lat_dim = -2 data = np.flip(data, axis=lat_dim) @@ -34,8 +38,6 @@ def plot_imshow( if use_colorbar: height, width = data.shape colorbar_width = max(1, int(0.025 * width)) - min_ = np.min(data) if vmin is None else vmin - max_ = np.max(data) if vmax is None else vmax range_ = np.linspace(min_, max_, height) range_ = np.repeat(range_[:, np.newaxis], repeats=colorbar_width, axis=1) range_ = np.flipud(range_) # wandb images start from top (and left) @@ -51,11 +53,55 @@ def plot_imshow( return fig +def fold_healpix_data(data: np.ndarray, fill_value: float) -> np.ndarray: + if data.shape[0] != 12: + raise ValueError( + "first dimension must be 12 (face) for healpix data, " + f"got shape {data.shape}" + ) + # we want to panel the data like this, numbered by first dimension index + # ----------------- + # | | | | | + # | | | |3 | + # ----------------- + # | | | | | + # | | |2 |7 | + # 
----------------- + # | | | | | + # | |1 |6 |10 | + # ----------------- + # | | | | | + # |0 |5 |9 | | + # ----------------- + # | | | | | + # |4 |8 | | | + # ----------------- + # | | | | | + # |11 | | | | + # ----------------- + blank_panel = np.full_like(data[0], fill_value) + panels = [ + [blank_panel, blank_panel, blank_panel, data[3]], + [blank_panel, blank_panel, data[2], data[7]], + [blank_panel, data[1], data[6], data[10]], + [data[0], data[5], data[9], blank_panel], + [data[4], data[8], blank_panel, blank_panel], + [data[11], blank_panel, blank_panel, blank_panel], + ] + return np.concatenate([np.concatenate(row, axis=1) for row in panels], axis=0) + + +def fold_if_healpix_data(data: np.ndarray, fill_value: float) -> np.ndarray: + if data.shape[0] == 12: + return fold_healpix_data(data, fill_value) + return data + + def plot_paneled_data( data: List[List[np.ndarray]], diverging: bool, caption: Optional[str] = None, -) -> Figure: +) -> Image: """Plot a list of 2D data arrays in a paneled plot.""" if diverging: cmap = "RdBu_r" @@ -77,7 +123,11 @@ def plot_paneled_data( caption += f"vmin={vmin:.4g}, vmax={vmax:.4g}." - all_data = _stitch_data_panels(data, vmin=vmin) + if diverging: + fill_value = 0.5 * (vmin + vmax) + else: + fill_value = vmin + all_data = _stitch_data_panels(data, fill_value=fill_value) fig = plot_imshow(all_data, vmin=vmin, vmax=vmax, cmap=cmap) wandb = WandB.get_instance() @@ -91,10 +141,14 @@ def plot_paneled_data( return wandb_image -def _stitch_data_panels(data: List[List[np.ndarray]], vmin) -> np.ndarray: +def _stitch_data_panels(data: List[List[np.ndarray]], fill_value) -> np.ndarray: for row in data: if len(row) != len(data[0]): raise ValueError("All rows must have the same number of panels.") + data = [ + [fold_if_healpix_data(arr, fill_value=fill_value) for arr in row] + for row in data + ] n_rows = len(data) n_cols = len(data[0]) for row in data: @@ -102,14 +156,12 @@ def _stitch_data_panels(data: List[List[np.ndarray]], vmin) -> np.ndarray: if arr.shape != data[0][0].shape: raise ValueError("All panels must have the same shape.") - stitched_data = ( - np.zeros( - ( - n_rows * data[0][0].shape[0] + n_rows - 1, - n_cols * data[0][0].shape[1] + n_cols - 1, - ) - ) - + vmin + stitched_data = np.full( + ( + n_rows * data[0][0].shape[0] + n_rows - 1, + n_cols * data[0][0].shape[1] + n_cols - 1, + ), + fill_value=fill_value, ) # iterate over rows backwards, as the image starts in the bottom left diff --git a/fme/fme/ace/aggregator/test_plotting.py b/fme/fme/ace/aggregator/test_plotting.py new file mode 100644 index 0000000..0143fb6 --- /dev/null +++ b/fme/fme/ace/aggregator/test_plotting.py @@ -0,0 +1,110 @@ +import numpy as np +import pytest + +from .plotting import ( + _stitch_data_panels, + fold_healpix_data, + get_cmap_limits, + plot_imshow, + plot_paneled_data, +) + + +def test_cmap_limits(): + data = np.array([1, 2, 3]) + vmin, vmax = get_cmap_limits(data) + assert vmin == 1 + assert vmax == 3 + + +def test_cmap_limits_diverging(): + data = np.array([-1, 2, 3]) + vmin, vmax = get_cmap_limits(data, diverging=True) + assert vmin == -3 + assert vmax == 3 + + +@pytest.mark.parametrize("use_colorbar", [True, False]) +def test_plot_imshow(use_colorbar): + shape = [10, 15] + data = np.random.randn(*shape) + fig = plot_imshow(np.array(data), use_colorbar=use_colorbar) + width, height = (fig.get_size_inches() * fig.dpi).astype(int) + if use_colorbar: + # colorbar is no more than 15% of the width but greater than 0 pixels + assert shape[1] < width <= int(shape[1] 
* 1.15) + assert height == shape[0] + else: + assert [height, width] == shape + + +def test_fold_healpix_data(): + face_shape = [2, 3] + data = np.random.randn(12, *face_shape) + folded = fold_healpix_data(data, fill_value=0) + expected_shape = (6 * face_shape[0], 4 * face_shape[1]) + assert folded.shape == expected_shape + + +@pytest.mark.parametrize("use_colorbar", [True, False]) +def test_plot_imshow_healpix(use_colorbar): + face_shape = [4, 6] + shape = [6 * face_shape[0], 4 * face_shape[1]] + data = np.random.randn(12, *face_shape) + fig = plot_imshow(np.array(data), use_colorbar=use_colorbar) + width, height = (fig.get_size_inches() * fig.dpi).astype(int) + if use_colorbar: + # colorbar is no more than 15% of the width but greater than 0 pixels + assert shape[1] < width <= int(shape[1] * 1.15) + assert height == shape[0] + else: + assert [height, width] == shape + + +def test_stitch_data_panels(): + data = [ + [np.array([[1, 2]]), np.array([[3, 4]])], + [np.array([[5, 6]]), np.array([[7, 8]])], + ] + stitched = _stitch_data_panels(data, fill_value=1) + expected = np.array( + [ # vertical orientation is swapped as data starts from bottom-left + [5, 6, 1, 7, 8], + [1, 1, 1, 1, 1], + [1, 2, 1, 3, 4], + ] + ) + assert np.array_equal(stitched, expected) + + +@pytest.mark.parametrize( + "shape, img_shape", + [ + pytest.param( + [12, 2, 3], + [ + 27, # 3 * 4 + 1 (divider) + 3 * 4 + 2 (colorbar) + 25, # 2 * 6 + 1 (divider) + 2 * 6 + 2 (colorbar) + ], + id="healpix", + ), + pytest.param( + [2, 3], + [ + 9, # 3 + 1 (divider) + 3 + 2 (colorbar) + 5, # 2 + 1 (divider) + 2 (colorbar) + ], + id="latlon", + ), + ], +) +def test_plot_paneled_data(shape, img_shape): + panel = np.random.uniform(size=shape) + data = [ + [panel, panel], + [panel, panel], + ] + fig = plot_paneled_data(data, diverging=False) + assert np.array_equal(fig.image.size, img_shape) + fig = plot_paneled_data(data, diverging=True) + assert np.array_equal(fig.image.size, img_shape) diff --git a/fme/fme/core/aggregator/train.py b/fme/fme/ace/aggregator/train.py similarity index 73% rename from fme/fme/core/aggregator/train.py rename to fme/fme/ace/aggregator/train.py index 6959615..bcecb69 100644 --- a/fme/fme/core/aggregator/train.py +++ b/fme/fme/ace/aggregator/train.py @@ -1,10 +1,14 @@ +from typing import Dict + import torch +from fme.ace.stepper import TrainOutput from fme.core.device import get_device from fme.core.distributed import Distributed +from fme.core.generics.aggregator import AggregatorABC -class TrainAggregator: +class TrainAggregator(AggregatorABC[TrainOutput]): """ Aggregates statistics for the first timestep. @@ -17,12 +21,12 @@ def __init__(self): self._loss = torch.tensor(0.0, device=get_device()) @torch.no_grad() - def record_batch(self, loss: float): - self._loss += loss + def record_batch(self, batch: TrainOutput): + self._loss += batch.metrics["loss"] self._n_batches += 1 @torch.no_grad() - def get_logs(self, label: str): + def get_logs(self, label: str) -> Dict[str, torch.Tensor]: """ Returns logs as can be reported to WandB. 
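
The `train.py` hunk above changes `TrainAggregator.record_batch` to accept a whole `TrainOutput` batch and to read the loss from `batch.metrics["loss"]` instead of taking a bare `loss` argument. A minimal sketch of the resulting call pattern, assuming `TrainOutput` can be constructed with only the fields used in `test_main.py` earlier in this patch (the identity `normalize` and the loss value are illustrative, not part of this changeset):

```python
import numpy as np
import torch
import xarray as xr

from fme.ace.aggregator.train import TrainAggregator
from fme.ace.stepper import TrainOutput

n_sample, n_time, n_lat, n_lon = 4, 2, 8, 16
target_data = {"a": torch.randn(n_sample, n_time, n_lat, n_lon)}
gen_data = {"a": torch.randn(n_sample, n_time, n_lat, n_lon)}

# Fields mirrored from the TrainOutput construction shown in test_main.py.
batch = TrainOutput(
    metrics={"loss": 0.25},
    target_data=target_data,
    gen_data=gen_data,
    time=xr.DataArray(np.zeros((n_sample, n_time)), dims=["sample", "time"]),
    normalize=lambda x: x,
)

agg = TrainAggregator()
agg.record_batch(batch)  # loss is now taken from batch.metrics, not passed in
logs = agg.get_logs(label="train")  # logs as can be reported to WandB
```
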
diff --git a/fme/fme/core/data_loading/__init__.py b/fme/fme/ace/data_loading/__init__.py similarity index 100% rename from fme/fme/core/data_loading/__init__.py rename to fme/fme/ace/data_loading/__init__.py diff --git a/fme/fme/ace/data_loading/batch_data.py b/fme/fme/ace/data_loading/batch_data.py new file mode 100644 index 0000000..bf177ce --- /dev/null +++ b/fme/fme/ace/data_loading/batch_data.py @@ -0,0 +1,571 @@ +import dataclasses +import datetime +import logging +from typing import ( + Any, + Callable, + Collection, + Dict, + Generic, + Iterable, + Iterator, + List, + Literal, + Mapping, + Optional, + Sequence, + Sized, + Tuple, + TypeVar, + Union, +) + +import numpy as np +import torch +import xarray as xr +from torch.utils.data import default_collate + +from fme.ace.requirements import PrognosticStateDataRequirements +from fme.core.coordinates import HorizontalCoordinates, HybridSigmaPressureCoordinate +from fme.core.dataset.data_typing import VariableMetadata +from fme.core.dataset.xarray import DatasetProperties +from fme.core.device import get_device +from fme.core.generics.data import DataLoader, GriddedDataABC, InferenceDataABC +from fme.core.gridded_ops import GriddedOperations +from fme.core.typing_ import TensorDict, TensorMapping + +SelfType = TypeVar("SelfType", bound="BatchData") + + +def _check_device(data: TensorMapping, device: torch.device): + for v in data.values(): + if v.device != device: + raise ValueError(f"data must be on {device}") + + +class PrognosticState: + """ + Thin typing wrapper around BatchData to indicate that the data is a prognostic + state, such as an initial condition or final state when evolving forward in time. + """ + + def __init__(self, data: "BatchData"): + """ + Initialize the state. + + Args: + data: The data to initialize the state with. + """ + self._data = data + + def to_device(self) -> "PrognosticState": + return PrognosticState(self._data.to_device()) + + def as_batch_data(self) -> "BatchData": + return self._data + + +@dataclasses.dataclass +class BatchData: + """A container for the data and time coordinates of a batch. + + Parameters: + data: Data for each variable in each sample of shape (sample, time, ...), + concatenated along samples to make a batch. To be used directly in training, + validation, and inference. + time: An array representing time coordinates for each sample in the batch, + concatenated along samples to make a batch. To be used in writing out + inference predictions with time coordinates, not directly in ML. + horizontal_dims: Horizontal dimensions of the data. Used for writing to + netCDF files. 
+ """ + + data: TensorMapping + time: xr.DataArray + horizontal_dims: List[str] = dataclasses.field( + default_factory=lambda: ["lat", "lon"] + ) + + @property + def dims(self) -> List[str]: + return ["sample", "time"] + self.horizontal_dims + + def to_device(self) -> "BatchData": + return self.__class__( + data={k: v.to(get_device()) for k, v in self.data.items()}, + time=self.time, + horizontal_dims=self.horizontal_dims, + ) + + @classmethod + def _get_kwargs(cls, horizontal_dims: Optional[List[str]]) -> Dict[str, Any]: + if horizontal_dims is None: + kwargs = {} + else: + kwargs = {"horizontal_dims": horizontal_dims} + return kwargs + + @classmethod + def new_on_cpu( + cls, + data: TensorMapping, + time: xr.DataArray, + horizontal_dims: Optional[List[str]] = None, + ) -> "BatchData": + _check_device(data, torch.device("cpu")) + kwargs = cls._get_kwargs(horizontal_dims) + return BatchData( + data=data, + time=time, + **kwargs, + ) + + @classmethod + def new_on_device( + cls, + data: TensorMapping, + time: xr.DataArray, + horizontal_dims: Optional[List[str]] = None, + ) -> "BatchData": + """ + Move the data to the current global device specified by get_device(). + """ + _check_device(data, get_device()) + kwargs = cls._get_kwargs(horizontal_dims) + return BatchData( + data=data, + time=time, + **kwargs, + ) + + def __post_init__(self): + if len(self.time.shape) != 2: + raise ValueError( + "Expected time to have shape (n_samples, n_times), got shape " + f"{self.time.shape}." + ) + for k, v in self.data.items(): + if v.shape[:2] != self.time.shape[:2]: + raise ValueError( + f"Data for variable {k} has shape {v.shape}, expected shape " + f"(n_samples, n_times) for time but got shape " + f"{self.time.shape}." + ) + + @classmethod + def from_sample_tuples( + cls, + samples: Sequence[Tuple[TensorMapping, xr.DataArray]], + sample_dim_name: str = "sample", + horizontal_dims: Optional[List[str]] = None, + ) -> "BatchData": + sample_data, sample_times = zip(*samples) + batch_data = default_collate(sample_data) + batch_time = xr.concat(sample_times, dim=sample_dim_name) + return BatchData.new_on_cpu( + data=batch_data, + time=batch_time, + horizontal_dims=horizontal_dims, + ) + + def compute_derived_variables( + self: SelfType, + derive_func: Callable[[TensorMapping, TensorMapping], TensorDict], + forcing_data: SelfType, + ) -> SelfType: + """ + Compute derived variables from the data and forcing data. + + The forcing data must have the same time coordinate as the batch data. + + Args: + derive_func: A function that takes the data and forcing data and returns a + dictionary of derived variables. + forcing_data: The forcing data to compute derived variables from. + """ + if not np.all(forcing_data.time.values == self.time.values): + raise ValueError( + "Forcing data must have the same time coordinate as the batch data." + ) + derived_data = derive_func(self.data, forcing_data.data) + return self.__class__( + data={**self.data, **derived_data}, + time=self.time, + horizontal_dims=self.horizontal_dims, + ) + + def remove_initial_condition(self: SelfType, n_ic_timesteps: int) -> SelfType: + """ + Remove the initial condition timesteps from the data. 
+ """ + if n_ic_timesteps == 0: + raise RuntimeError("No initial condition timesteps to remove.") + return self.__class__( + {k: v[:, n_ic_timesteps:] for k, v in self.data.items()}, + time=self.time.isel(time=slice(n_ic_timesteps, None)), + horizontal_dims=self.horizontal_dims, + ) + + def subset_names(self: SelfType, names: Collection[str]) -> SelfType: + """ + Subset the data to only include the given names. + """ + return self.__class__( + {k: v for k, v in self.data.items() if k in names}, + time=self.time, + horizontal_dims=self.horizontal_dims, + ) + + def get_start( + self: SelfType, prognostic_names: Collection[str], n_ic_timesteps: int + ) -> PrognosticState: + """ + Get the initial condition state. + """ + return PrognosticState( + self.subset_names(prognostic_names).select_time_slice( + slice(0, n_ic_timesteps) + ) + ) + + def get_end( + self: SelfType, prognostic_names: Collection[str], n_ic_timesteps: int + ) -> PrognosticState: + """ + Get the final state which can be used as a new initial condition. + """ + return PrognosticState( + self.subset_names(prognostic_names).select_time_slice( + slice(-n_ic_timesteps, None) + ) + ) + + def select_time_slice(self: SelfType, time_slice: slice) -> SelfType: + """ + Select a window of data from the batch. + """ + return self.__class__( + {k: v[:, time_slice] for k, v in self.data.items()}, + time=self.time[:, time_slice], + horizontal_dims=self.horizontal_dims, + ) + + def prepend(self: SelfType, initial_condition: PrognosticState) -> SelfType: + """ + Prepend the initial condition to the data. + """ + initial_batch_data = initial_condition.as_batch_data() + filled_data = {**initial_batch_data.data} + example_tensor = list(initial_batch_data.data.values())[0] + state_data_device = list(self.data.values())[0].device + for k in self.data: + if k not in filled_data: + filled_data[k] = torch.full_like(example_tensor, fill_value=np.nan) + return self.__class__( + data={ + k: torch.cat([filled_data[k].to(state_data_device), v], dim=1) + for k, v in self.data.items() + }, + time=xr.concat([initial_batch_data.time, self.time], dim="time"), + horizontal_dims=self.horizontal_dims, + ) + + +@dataclasses.dataclass +class PairedData: + """A container for the data and time coordinate of a batch, with paired + prediction and target data. 
+ """ + + prediction: TensorMapping + target: TensorMapping + time: xr.DataArray + + @classmethod + def from_batch_data( + cls, + prediction: BatchData, + target: BatchData, + ) -> "PairedData": + if not np.all(prediction.time.values == target.time.values): + raise ValueError("Prediction and target time coordinate must be the same.") + return PairedData(prediction.data, target.data, prediction.time) + + @classmethod + def new_on_device( + cls, + prediction: TensorMapping, + target: TensorMapping, + time: xr.DataArray, + ) -> "PairedData": + device = get_device() + _check_device(prediction, device) + _check_device(target, device) + return PairedData(prediction, target, time) + + @classmethod + def new_on_cpu( + cls, + prediction: TensorMapping, + target: TensorMapping, + time: xr.DataArray, + ) -> "PairedData": + _check_device(prediction, torch.device("cpu")) + _check_device(target, torch.device("cpu")) + return PairedData(prediction, target, time) + + +T = TypeVar("T", covariant=True) + + +U = TypeVar("U") + + +class SizedMap(Generic[T, U], Sized, Iterable[U]): + def __init__(self, func: Callable[[T], U], iterable: DataLoader[T]): + self._func = func + self._iterable = iterable + + def __len__(self) -> int: + return len(self._iterable) + + def __iter__(self) -> Iterator[U]: + return map(self._func, self._iterable) + + +def get_initial_condition( + loader: DataLoader[BatchData], + requirements: PrognosticStateDataRequirements, +) -> PrognosticState: + for batch in loader: + return batch.to_device().get_start( + prognostic_names=requirements.names, + n_ic_timesteps=requirements.n_timesteps, + ) + raise ValueError("No initial condition found in loader") + + +class InferenceGriddedData(InferenceDataABC[PrognosticState, BatchData]): + """ + Data as required for inference. + + All data exposed from this class is on the current device. + """ + + def __init__( + self, + loader: DataLoader[BatchData], + initial_condition: Union[PrognosticState, PrognosticStateDataRequirements], + properties: DatasetProperties, + ): + """ + Args: + loader: torch DataLoader, which returns batches of type + TensorMapping where keys indicate variable name. + Each tensor has shape + [batch_size, face, time_window_size, n_channels, n_x_coord, n_y_coord]. + Data can be on any device (but will typically be on CPU). + initial_condition: Initial condition for the inference, or a requirements + object specifying how to extract the initial condition from the first + batch of data. Data can be on any device. + properties: Batch-constant properties for the dataset, such as variable + metadata and coordinate information. Data can be on any device. + + Note: + While input data can be on any device, all data exposed from this class + will be on the current device. 
+ """ + self._loader = loader + self._properties = properties.to_device() + self._n_initial_conditions: Optional[int] = None + if isinstance(initial_condition, PrognosticStateDataRequirements): + self._initial_condition: PrognosticState = get_initial_condition( + loader, initial_condition + ) + else: + self._initial_condition = initial_condition.to_device() + + @property + def loader(self) -> DataLoader[BatchData]: + def on_device(batch: BatchData) -> BatchData: + return batch.to_device() + + return SizedMap(on_device, self._loader) + + @property + def variable_metadata(self) -> Mapping[str, VariableMetadata]: + return self._properties.variable_metadata + + @property + def vertical_coordinate(self) -> HybridSigmaPressureCoordinate: + return self._properties.vertical_coordinate + + @property + def horizontal_coordinates(self) -> HorizontalCoordinates: + return self._properties.horizontal_coordinates + + @property + def timestep(self) -> datetime.timedelta: + return self._properties.timestep + + @property + def coords(self) -> Mapping[str, np.ndarray]: + return { + **self.horizontal_coordinates.coords, + **self.vertical_coordinate.coords, + } + + @property + def gridded_operations(self) -> GriddedOperations: + return self.horizontal_coordinates.gridded_operations + + @property + def _n_samples(self) -> int: + return len(self._loader.dataset) # type: ignore + + @property + def _n_batches(self) -> int: + return len(self._loader) # type: ignore + + @property + def _first_time(self) -> Any: + return self._loader.dataset[0][1].values[0] # type: ignore + + @property + def _last_time(self) -> Any: + return self._loader.dataset[-1][1].values[0] # type: ignore + + @property + def n_initial_conditions(self) -> int: + if self._n_initial_conditions is None: + example_data = next(iter(self.loader)).data + example_tensor = next(iter(example_data.values())) + self._n_initial_conditions = example_tensor.shape[0] + return self._n_initial_conditions + + @property + def initial_condition(self) -> PrognosticState: + return self._initial_condition + + def log_info(self, name: str): + logging.info( + f"{name} data: {self._n_samples} samples, " f"{self._n_batches} batches" + ) + logging.info(f"{name} data: first sample's initial time: {self._first_time}") + logging.info(f"{name} data: last sample's initial time: {self._last_time}") + + +class GriddedData(GriddedDataABC[BatchData]): + """ + Data as required for pytorch training. + + The data is assumed to be gridded, and attributes are included for + performing operations on gridded data. + + All data exposed from this class is on the current device. + """ + + def __init__( + self, + loader: DataLoader[BatchData], + properties: DatasetProperties, + sampler: Optional[torch.utils.data.Sampler] = None, + ): + """ + Args: + loader: torch DataLoader, which returns batches of type + TensorMapping where keys indicate variable name. + Each tensor has shape + [batch_size, face, time_window_size, n_channels, n_x_coord, n_y_coord]. + Data can be on any device (but will typically be on CPU). + properties: Batch-constant properties for the dataset, such as variable + metadata and coordinate information. Data can be on any device. + sampler: Optional sampler for the data loader. Provided to allow support for + distributed training. + + Note: + While input data can be on any device, all data exposed from this class + will be on the current device. 
+ """ + self._loader = loader + self._properties = properties.to_device() + self._sampler = sampler + self._batch_size: Optional[int] = None + + @property + def loader(self) -> DataLoader[BatchData]: + def on_device(batch: BatchData) -> BatchData: + return batch.to_device() + + return SizedMap(on_device, self._loader) + + @property + def variable_metadata(self) -> Mapping[str, VariableMetadata]: + return self._properties.variable_metadata + + @property + def vertical_coordinate(self) -> HybridSigmaPressureCoordinate: + return self._properties.vertical_coordinate + + @property + def horizontal_coordinates(self) -> HorizontalCoordinates: + return self._properties.horizontal_coordinates + + @property + def timestep(self) -> datetime.timedelta: + return self._properties.timestep + + @property + def coords(self) -> Mapping[str, np.ndarray]: + return { + **self.horizontal_coordinates.coords, + **self.vertical_coordinate.coords, + } + + @property + def grid(self) -> Literal["equiangular", "legendre-gauss", "healpix"]: + return self.horizontal_coordinates.grid + + @property + def gridded_operations(self) -> GriddedOperations: + return self.horizontal_coordinates.gridded_operations + + @property + def n_samples(self) -> int: + return len(self._loader.dataset) # type: ignore + + @property + def n_batches(self) -> int: + return len(self._loader) # type: ignore + + @property + def _first_time(self) -> Any: + return self._loader.dataset[0][1].values[0] # type: ignore + + @property + def _last_time(self) -> Any: + return self._loader.dataset[-1][1].values[0] # type: ignore + + @property + def batch_size(self) -> int: + if self._batch_size is None: + example_data = next(iter(self.loader)).data + example_tensor = next(iter(example_data.values())) + self._batch_size = example_tensor.shape[0] + return self._batch_size + + def log_info(self, name: str): + logging.info( + f"{name} data: {self.n_samples} samples, " f"{self.n_batches} batches" + ) + logging.info(f"{name} data: first sample's initial time: {self._first_time}") + logging.info(f"{name} data: last sample's initial time: {self._last_time}") + + def set_epoch(self, epoch: int): + """ + Set the epoch for the data loader sampler, if it is a distributed sampler. + """ + if self._sampler is not None and isinstance( + self._sampler, torch.utils.data.DistributedSampler + ): + self._sampler.set_epoch(epoch) diff --git a/fme/fme/ace/data_loading/config.py b/fme/fme/ace/data_loading/config.py new file mode 100644 index 0000000..292d356 --- /dev/null +++ b/fme/fme/ace/data_loading/config.py @@ -0,0 +1,34 @@ +import dataclasses +from typing import Optional, Sequence + +from fme.core.dataset.config import XarrayDataConfig +from fme.core.distributed import Distributed + + +@dataclasses.dataclass +class DataLoaderConfig: + """ + Parameters: + dataset: A sequence of configurations each defining a dataset + to be loaded. This sequence of datasets will be concatenated. + batch_size: Number of samples per batch. + num_data_workers: Number of parallel workers to use for data loading. + prefetch_factor: how many batches a single data worker will attempt to + hold in host memory at a given time. + strict_ensemble: Whether to enforce that the datasets to be concatened + have the same dimensions and coordinates. 
+ """ + + dataset: Sequence[XarrayDataConfig] + batch_size: int + num_data_workers: int = 0 + prefetch_factor: Optional[int] = None + strict_ensemble: bool = True + + def __post_init__(self): + dist = Distributed.get_instance() + if self.batch_size % dist.world_size != 0: + raise ValueError( + "batch_size must be divisible by the number of parallel " + f"workers, got {self.batch_size} and {dist.world_size}" + ) diff --git a/fme/fme/ace/data_loading/getters.py b/fme/fme/ace/data_loading/getters.py new file mode 100644 index 0000000..557296f --- /dev/null +++ b/fme/fme/ace/data_loading/getters.py @@ -0,0 +1,206 @@ +import logging +from typing import Optional, Union + +import torch.utils.data +from torch.utils.data.distributed import DistributedSampler +from torch.utils.data.sampler import RandomSampler + +from fme.ace.data_loading.batch_data import BatchData +from fme.ace.requirements import PrognosticStateDataRequirements +from fme.core.dataset.getters import get_dataset +from fme.core.dataset.requirements import DataRequirements +from fme.core.dataset.xarray import XarrayDataset +from fme.core.device import using_gpu +from fme.core.distributed import Distributed + +from .batch_data import GriddedData, InferenceGriddedData, PrognosticState +from .config import DataLoaderConfig +from .inference import ( + ExplicitIndices, + ForcingDataLoaderConfig, + InferenceDataLoaderConfig, + InferenceDataset, +) + +logger = logging.getLogger(__name__) + + +def get_data_loader( + config: DataLoaderConfig, + train: bool, + requirements: DataRequirements, +) -> GriddedData: + """ + Args: + config: Parameters for the data loader. + train: Whether loader is intended for training or validation data; if True, + then data will be shuffled. + requirements: Data requirements for the model. + """ + dataset, properties = get_dataset( + config.dataset, requirements, strict=config.strict_ensemble + ) + dist = Distributed.get_instance() + + if dist.is_distributed(): + sampler = DistributedSampler(dataset, shuffle=train) + else: + sampler = RandomSampler(dataset) if train else None + + if properties.is_remote: + # GCSFS and S3FS are not fork-safe, so we need to use forkserver + mp_context = "forkserver" + persistent_workers = True + else: + mp_context = None + persistent_workers = False + + def collate_fn(samples): + return BatchData.from_sample_tuples( + samples, + horizontal_dims=list(properties.horizontal_coordinates.dims), + ) + + batch_size = dist.local_batch_size(int(config.batch_size)) + + if config.prefetch_factor is None: + # DataLoader default is not None so we must leave it unset + kwargs = {} + else: + kwargs = {"prefetch_factor": config.prefetch_factor} + + dataloader = torch.utils.data.DataLoader( + dataset, + batch_size=batch_size, + num_workers=config.num_data_workers, + sampler=sampler, + drop_last=True, + pin_memory=using_gpu(), + collate_fn=collate_fn, + multiprocessing_context=mp_context, + persistent_workers=persistent_workers, + **kwargs, + ) + + if len(dataloader) == 0: + raise ValueError( + "No batches in dataloader: " + f"{len(dataloader.dataset)} samples, {len(dataloader)} batches. 
" + f"Batch size is {dataloader.batch_size}" + ) + + return GriddedData( + loader=dataloader, + properties=properties, + sampler=sampler, + ) + + +def get_inference_data( + config: InferenceDataLoaderConfig, + total_forward_steps: int, + window_requirements: DataRequirements, + initial_condition: Union[PrognosticState, PrognosticStateDataRequirements], + surface_temperature_name: Optional[str] = None, + ocean_fraction_name: Optional[str] = None, +) -> InferenceGriddedData: + """ + Args: + config: Parameters for the data loader. + total_forward_steps: Total number of forward steps to take over the course of + inference. + window_requirements: Data requirements for the model. + initial_condition: Initial condition for the inference, or a requirements object + specifying how to extract the initial condition from the first batch of + data + surface_temperature_name: Name of the surface temperature variable. Can be + set to None if no ocean temperature prescribing is being used. + ocean_fraction_name: Name of the ocean fraction variable. Can be set to None + if no ocean temperature prescribing is being used. + + Returns: + A data loader for inference with coordinates and metadata. + """ + dataset = InferenceDataset( + config, + total_forward_steps, + window_requirements, + surface_temperature_name, + ocean_fraction_name, + ) + properties = dataset.properties + + if properties.is_remote: + # GCSFS and S3FS are not fork-safe, so we need to use forkserver + # persist workers since startup is slow + mp_context = "forkserver" + persistent_workers = True + else: + mp_context = None + persistent_workers = False + + logging.info(f"Multiprocessing inference context: {mp_context or 'fork'}") + + # we roll our own batching in InferenceDataset, which is why batch_size=None below + loader = torch.utils.data.DataLoader( + dataset, + batch_size=None, + num_workers=config.num_data_workers, + shuffle=False, + pin_memory=using_gpu(), + multiprocessing_context=mp_context, + persistent_workers=persistent_workers, + ) + gridded_data = InferenceGriddedData( + loader=loader, + initial_condition=initial_condition, + properties=properties, + ) + + return gridded_data + + +def get_forcing_data( + config: ForcingDataLoaderConfig, + total_forward_steps: int, + window_requirements: DataRequirements, + initial_condition: PrognosticState, + surface_temperature_name: Optional[str] = None, + ocean_fraction_name: Optional[str] = None, +) -> InferenceGriddedData: + """Return a GriddedData loader for forcing data based on the initial condition. + This function determines the start indices for the forcing data based on the initial + time in the provided initial condition. + + Args: + config: Parameters for the forcing data loader. + total_forward_steps: Total number of forward steps to take over the course of + inference. + window_requirements: Data requirements for the forcing data. + initial_condition: Initial condition for the inference. + surface_temperature_name: Name of the surface temperature variable. Can be + set to None if no ocean temperature prescribing is being used. + ocean_fraction_name: Name of the ocean fraction variable. Can be set to None + if no ocean temperature prescribing is being used. + + Returns: + A data loader for forcing data with coordinates and metadata. 
+ """ + initial_time = initial_condition.as_batch_data().time + if initial_time.shape[1] != 1: + raise NotImplementedError("code assumes initial time only has 1 timestep") + available_times = XarrayDataset(config.dataset, window_requirements).all_times + start_time_indices = [] + for time in initial_time.values[:, 0]: + start_time_indices.append(available_times.get_loc(time)) + inference_config = config.build_inference_config( + start_indices=ExplicitIndices(start_time_indices) + ) + return get_inference_data( + config=inference_config, + total_forward_steps=total_forward_steps, + window_requirements=window_requirements, + initial_condition=initial_condition, + surface_temperature_name=surface_temperature_name, + ocean_fraction_name=ocean_fraction_name, + ) diff --git a/fme/fme/ace/data_loading/inference.py b/fme/fme/ace/data_loading/inference.py new file mode 100644 index 0000000..312d286 --- /dev/null +++ b/fme/fme/ace/data_loading/inference.py @@ -0,0 +1,295 @@ +import dataclasses +import logging +from math import ceil +from typing import Optional, Sequence, Union + +import cftime +import numpy as np +import torch +import xarray as xr + +from fme.ace.data_loading.batch_data import BatchData +from fme.ace.data_loading.perturbation import SSTPerturbation +from fme.core.coordinates import LatLonCoordinates +from fme.core.dataset.config import XarrayDataConfig +from fme.core.dataset.requirements import DataRequirements +from fme.core.dataset.xarray import DatasetProperties, XarrayDataset +from fme.core.distributed import Distributed +from fme.core.typing_ import Slice + + +@dataclasses.dataclass +class TimestampList: + """ + Configuration for a list of timestamps. + + Parameters: + times: List of timestamps. + timestamp_format: Format of the timestamps. + """ + + times: Sequence[str] + timestamp_format: str = "%Y-%m-%dT%H:%M:%S" + + def as_indices(self, time_index: xr.CFTimeIndex) -> np.ndarray: + datetimes = [ + cftime.datetime.strptime( + t, self.timestamp_format, calendar=time_index.calendar + ) + for t in self.times + ] + (indices,) = time_index.isin(datetimes).nonzero() + if len(indices) != len(self.times): + missing_times = set(datetimes) - set(time_index[indices]) + raise ValueError( + f"Inference initial condition timestamps {missing_times} " + "were not found in the dataset." + ) + return indices + + @property + def n_initial_conditions(self) -> int: + return len(self.times) + + +@dataclasses.dataclass +class InferenceInitialConditionIndices: + """ + Configuration of the indices for initial conditions during inference. + + Parameters: + n_initial_conditions: Number of initial conditions to use. + first: Index of the first initial condition. + interval: Interval between initial conditions. + """ + + n_initial_conditions: int + first: int = 0 + interval: int = 1 + + def __post_init__(self): + if self.interval < 0: + raise ValueError("interval must be positive") + + def as_indices(self) -> np.ndarray: + stop = self.n_initial_conditions * self.interval + self.first + return np.arange(self.first, stop, self.interval) + + +@dataclasses.dataclass +class ExplicitIndices: + """ + Configure indices providing them explicitly. + + Parameters: + list: List of integer indices. + """ + + list: Sequence[int] + + def as_indices(self) -> np.ndarray: + return np.array(self.list) + + @property + def n_initial_conditions(self) -> int: + return len(self.list) + + +@dataclasses.dataclass +class InferenceDataLoaderConfig: + """ + Configuration for inference data. 
+ + This is like the `DataLoaderConfig` class, but with some additional + constraints. During inference, we have only one batch, so the number of + samples directly determines the size of that batch. + + Parameters: + dataset: Configuration to define the dataset. + start_indices: Configuration of the indices for initial conditions + during inference. This can be a list of timestamps, a list of + integer indices, or a slice configuration of the integer indices. + Values following the initial condition will still come from + the full dataset. + num_data_workers: Number of parallel workers to use for data loading. + perturbations: Configuration for SST perturbations. + persistence_names: Names of variables for which all returned values + will be the same as the initial condition. When evaluating initial + condition predictability, set this to forcing variables that should + not be updated during inference (e.g. surface temperature). + """ + + dataset: XarrayDataConfig + start_indices: Union[ + InferenceInitialConditionIndices, ExplicitIndices, TimestampList + ] + num_data_workers: int = 0 + perturbations: Optional[SSTPerturbation] = None + persistence_names: Optional[Sequence[str]] = None + + def __post_init__(self): + if self.dataset.subset != Slice(None, None, None): + raise ValueError("Inference data may not be subset.") + + @property + def n_initial_conditions(self) -> int: + return self.start_indices.n_initial_conditions + + +@dataclasses.dataclass +class ForcingDataLoaderConfig: + """ + Configuration for the forcing data. + + Parameters: + dataset: Configuration to define the dataset. + num_data_workers: Number of parallel workers to use for data loading. + perturbations: Configuration for SST perturbations + used in forcing data. + persistence_names: Names of variables for which all returned values + will be the same as the initial condition. When evaluating initial + condition predictability, set this to forcing variables that should + not be updated during inference (e.g. surface temperature). 
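# Sketch: an inference loader config that starts from an explicit timestamp.
# The data path and timestamp are illustrative.
config = InferenceDataLoaderConfig(
    dataset=XarrayDataConfig(data_path="/path/to/forcing-data"),
    start_indices=TimestampList(times=["1970-01-01T00:00:00"]),
    num_data_workers=2,
)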
+ """ + + dataset: XarrayDataConfig + num_data_workers: int = 0 + perturbations: Optional[SSTPerturbation] = None + persistence_names: Optional[Sequence[str]] = None + + def __post_init__(self): + if self.dataset.subset != Slice(None, None, None): + raise ValueError("Inference data may not be subset.") + + def build_inference_config(self, start_indices: ExplicitIndices): + return InferenceDataLoaderConfig( + dataset=self.dataset, + num_data_workers=self.num_data_workers, + start_indices=start_indices, + perturbations=self.perturbations, + persistence_names=self.persistence_names, + ) + + +class InferenceDataset(torch.utils.data.Dataset): + def __init__( + self, + config: InferenceDataLoaderConfig, + total_forward_steps: int, + requirements: DataRequirements, + surface_temperature_name: Optional[str] = None, + ocean_fraction_name: Optional[str] = None, + ): + dataset = XarrayDataset(config.dataset, requirements=requirements) + self._dataset = dataset + self._properties = dataset.properties + self._forward_steps_in_memory = requirements.n_timesteps - 1 + self._total_forward_steps = total_forward_steps + self._perturbations = config.perturbations + self._surface_temperature_name = surface_temperature_name + self._ocean_fraction_name = ocean_fraction_name + self._n_initial_conditions = config.n_initial_conditions + + if isinstance(config.start_indices, TimestampList): + self._start_indices = config.start_indices.as_indices(dataset.all_times) + else: + self._start_indices = config.start_indices.as_indices() + self._validate_n_forward_steps() + if isinstance(self._properties.horizontal_coordinates, LatLonCoordinates): + self._lats, self._lons = self._properties.horizontal_coordinates.meshgrid + else: + if self._perturbations is not None: + raise ValueError( + "Currently, SST perturbations are only supported \ + for lat/lon coordinates." + ) + if self._perturbations is not None and ( + self._surface_temperature_name is None or self._ocean_fraction_name is None + ): + raise ValueError( + "No ocean configuration found, \ + SST perturbations require an ocean configuration." + ) + + self._persistence_data: Optional[BatchData] = None + if config.persistence_names is not None: + first_sample = self._get_batch_data(0) + self._persistence_data = first_sample.subset_names( + config.persistence_names + ).select_time_slice(slice(0, 1)) + + def _get_batch_data(self, index) -> BatchData: + dist = Distributed.get_instance() + i_start = index * self._forward_steps_in_memory + sample_tuples = [] + for i_member in range(self._n_initial_conditions): + # check if sample is one this local rank should process + if i_member % dist.world_size != dist.rank: + continue + i_window_start = i_start + self._start_indices[i_member] + i_window_end = i_window_start + self._forward_steps_in_memory + 1 + if i_window_end > ( + self._total_forward_steps + self._start_indices[i_member] + ): + i_window_end = ( + self._total_forward_steps + self._start_indices[i_member] + 1 + ) + window_time_slice = slice(i_window_start, i_window_end) + tensors, time = self._dataset.get_sample_by_time_slice(window_time_slice) + if self._perturbations is not None: + if ( + self._surface_temperature_name is None + or self._ocean_fraction_name is None + ): + raise ValueError( + "Surface temperature and ocean fraction names must be provided \ + to apply SST perturbations." 
+ ) + logging.debug("Applying SST perturbations to forcing data") + for perturbation in self._perturbations.perturbations: + perturbation.apply_perturbation( + tensors[self._surface_temperature_name], + self._lats, + self._lons, + tensors[self._ocean_fraction_name], + ) + sample_tuples.append((tensors, time)) + return BatchData.from_sample_tuples( + sample_tuples, + horizontal_dims=list(self.properties.horizontal_coordinates.dims), + ) + + def __getitem__(self, index) -> BatchData: + dist = Distributed.get_instance() + result = self._get_batch_data(index) + if self._persistence_data is not None: + updated_data = {} + for key, value in self._persistence_data.data.items(): + updated_data[key] = value.expand_as(result.data[key]) + result.data = {**result.data, **updated_data} + assert result.time.shape[0] == self._n_initial_conditions // dist.world_size + return result + + def __len__(self) -> int: + # The ceil is necessary so if the last batch is smaller + # than the rest the ratio will be rounded up and the last batch + # will be included in the loading + return int(ceil(self._total_forward_steps / self._forward_steps_in_memory)) + + @property + def properties(self) -> DatasetProperties: + return self._properties + + @property + def n_forward_steps(self) -> int: + return self._total_forward_steps + + def _validate_n_forward_steps(self): + max_steps = self._dataset.total_timesteps - max(self._start_indices) - 1 + if self._total_forward_steps > max_steps: + raise ValueError( + f"The number of forward inference steps ({self._total_forward_steps}) " + "must be less than or equal to the number of possible steps " + f"({max_steps}) in dataset after the last initial condition's " + "start index." + ) diff --git a/fme/fme/ace/data_loading/perturbation.py b/fme/fme/ace/data_loading/perturbation.py new file mode 100644 index 0000000..d21021d --- /dev/null +++ b/fme/fme/ace/data_loading/perturbation.py @@ -0,0 +1,194 @@ +import abc +import dataclasses +from typing import Any, Callable, ClassVar, Mapping, Tuple, Type, TypeVar + +import dacite +import numpy as np +import torch + +from fme.core.registry.registry import Registry + + +@dataclasses.dataclass +class PerturbationConfig(abc.ABC): + """ + Returns a perturbation function config class. + """ + + @classmethod + def from_state(cls, state: Mapping[str, Any]) -> "PerturbationConfig": + """ + Create a PerturbationSelector from a dictionary containing all the information + needed to build a PerturbationConfig. + """ + return dacite.from_dict( + data_class=cls, data=state, config=dacite.Config(strict=True) + ) + + @abc.abstractmethod + def apply_perturbation( + self, + data: torch.Tensor, + lat: torch.Tensor, + lon: torch.Tensor, + ocean_fraction: torch.Tensor, + ) -> None: ... + + +PT = TypeVar("PT", bound=Type[PerturbationConfig]) + + +@dataclasses.dataclass +class PerturbationSelector: + type: str + config: Mapping[str, Any] + registry: ClassVar[Registry] = Registry() + + def __post__init(self): + if self.registry is not Registry(): + raise ValueError("PerturbationSelector.registry should not be set manually") + + @classmethod + def register(cls, type_name) -> Callable[[PT], PT]: + return cls.registry.register(type_name) + + def build(self) -> PerturbationConfig: + return self.registry.from_dict(self.get_state()) + + def get_state(self) -> Mapping[str, Any]: + """ + Get a dictionary containing all the information needed to build a + PerturbationConfig. 
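# Sketch: registering a new perturbation type through the selector registry
# above. The "uniform_offset" name and class are hypothetical.
import dataclasses


@PerturbationSelector.register("uniform_offset")
@dataclasses.dataclass
class UniformOffsetConfig(PerturbationConfig):
    offset: float = 0.0

    def apply_perturbation(self, data, lat, lon, ocean_fraction):
        # only perturb ocean points, mirroring the built-in configs below
        data[ocean_fraction > 0.5] += self.offset


perturbation = PerturbationSelector(
    type="uniform_offset", config={"offset": 1.5}
).build()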
+ + """ + return {"type": self.type, "config": self.config} + + @classmethod + def get_available_types(cls): + """This class method is used to expose all available types of Perturbations.""" + return cls(type="", config={}).registry._types.keys() + + +@dataclasses.dataclass +class SSTPerturbation: + """ + Configuration for sea surface temperature perturbations + applied to initial condition and forcing data. + Currently, this is strictly applied to both. + + Parameters: + sst: List of perturbation selectors for SST perturbations. + """ + + sst: list[PerturbationSelector] + + def __post_init__(self): + self.perturbations: list[PerturbationConfig] = [ + perturbation.build() for perturbation in self.sst + ] + + +def _get_ocean_mask(ocean_fraction: torch.Tensor, cutoff: float = 0.5) -> torch.Tensor: + return ocean_fraction > cutoff + + +@PerturbationSelector.register("constant") +@dataclasses.dataclass +class ConstantConfig(PerturbationConfig): + """ + Configuration for a constant perturbation. + """ + + amplitude: float = 1.0 + + def apply_perturbation( + self, + data: torch.Tensor, + lat: torch.Tensor, + lon: torch.Tensor, + ocean_fraction: torch.Tensor, + ): + ocean_mask = _get_ocean_mask(ocean_fraction) + data[ocean_mask] += self.amplitude # type: ignore + + +@PerturbationSelector.register("greens_function") +@dataclasses.dataclass +class GreensFunctionConfig(PerturbationConfig): + """ + Configuration for a single sinusoidal patch of a Green's function perturbation. + See equation 1 in Bloch‐Johnson, J., et al. (2024). + + Parameters: + amplitude: The amplitude of the perturbation, + maximum is reached at (lat_center, lon_center). + lat_center: The latitude at the center of the patch in degrees. + lon_center: The longitude at the center of the patch in degrees. + lat_width: latitudinal width of the patch in degrees. + lon_width: longitudinal width of the patch in degrees. + """ + + amplitude: float = 1.0 + lat_center: float = 0.0 + lon_center: float = 0.0 + lat_width: float = 10.0 + lon_width: float = 10.0 + + def __post_init__(self): + self._lat_center_rad = np.deg2rad(self.lat_center) + self._lon_center_rad = np.deg2rad(self.lon_center) + self._lat_width_rad = np.deg2rad(self.lat_width) + self._lon_width_rad = np.deg2rad(self.lon_width) + + def _wrap_longitude_discontinuity( + self, + lon: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Assume longitude is in the range [0, 360) degrees. + If the patch crosses the discontinuity at 0/360 degrees, + shift the longitude accordingly. 
+ """ + lon_min = self.lon_center - self.lon_width / 2.0 + lon_max = self.lon_center + self.lon_width / 2.0 + if lon_min < 0: + lon_shifted = ((lon + 180) % 360) - 180 + lon_in_patch = (lon_shifted > lon_min) & (lon_shifted < lon_max) + elif lon_max > 360: + lon_in_patch = (lon > lon_min) | (lon < lon_max % 360) + lon_shifted = ((lon + 180) % 360) + 180 + else: + lon_in_patch = (lon > lon_min) & (lon < lon_max) + lon_shifted = lon + return lon_in_patch, lon_shifted + + def apply_perturbation( + self, + data: torch.Tensor, + lat: torch.Tensor, + lon: torch.Tensor, + ocean_fraction: torch.Tensor, + ): + lat_in_patch = torch.abs(lat - self.lat_center) < self.lat_width / 2.0 + lon_in_patch, lon_shifted = self._wrap_longitude_discontinuity(lon) + mask = lat_in_patch & lon_in_patch + ocean_mask = _get_ocean_mask(ocean_fraction) + perturbation = self.amplitude * ( + torch.cos( + torch.pi + / 2 + * (lat.deg2rad() - self._lat_center_rad) + / (self._lat_width_rad / 2.0) + ) + ** 2 + * torch.cos( + torch.pi + / 2 + * (lon_shifted.deg2rad() - self._lon_center_rad) + / (self._lon_width_rad / 2.0) + ) + ** 2 + ) + mask = mask.expand(data.shape) + perturbation = perturbation.expand(data.shape) + data[mask & ocean_mask] += perturbation[mask & ocean_mask] diff --git a/fme/fme/ace/data_loading/test_batch_data.py b/fme/fme/ace/data_loading/test_batch_data.py new file mode 100644 index 0000000..d4a2643 --- /dev/null +++ b/fme/fme/ace/data_loading/test_batch_data.py @@ -0,0 +1,147 @@ +from typing import List + +import numpy as np +import pytest +import torch +import xarray as xr + +from fme.ace.data_loading.batch_data import BatchData +from fme.core.device import get_device + + +def assert_metadata_equal(a: BatchData, b: BatchData): + assert a.horizontal_dims == b.horizontal_dims + + +def get_batch_data( + names: List[str], + n_samples: int, + n_times: int, + horizontal_dims: List[str], + n_lat: int = 8, + n_lon: int = 16, +): + device = get_device() + return BatchData( + data={ + name: torch.randn(n_samples, n_times, n_lat, n_lon, device=device) + for name in names + }, + time=xr.DataArray(np.random.rand(n_samples, n_times), dims=["sample", "time"]), + horizontal_dims=horizontal_dims, + ) + + +@pytest.mark.parametrize( + "names, prognostic_names", + [ + pytest.param(["prog"], ["prog"], id="all prognostic"), + pytest.param(["prog", "forcing"], ["prog"], id="some prognostic"), + pytest.param(["forcing1", "forcing2"], [], id="no prognostic"), + ], +) +@pytest.mark.parametrize("n_ic_timesteps", [1, 2]) +def test_get_start(names: List[str], prognostic_names: List[str], n_ic_timesteps: int): + n_samples = 2 + n_times = 5 + n_lat = 8 + n_lon = 16 + horizontal_dims = ["lat", "lon"] + batch_data = get_batch_data( + names=names, + n_samples=n_samples, + n_times=n_times, + horizontal_dims=horizontal_dims, + n_lat=n_lat, + n_lon=n_lon, + ) + start = batch_data.get_start(prognostic_names, n_ic_timesteps).as_batch_data() + assert_metadata_equal(start, batch_data) + assert start.time.equals(batch_data.time.isel(time=slice(0, n_ic_timesteps))) + assert set(start.data.keys()) == set(prognostic_names) + for name in prognostic_names: + assert start.data[name].shape == (n_samples, n_ic_timesteps, n_lat, n_lon) + np.testing.assert_allclose( + start.data[name].cpu().numpy(), + batch_data.data[name][:, :n_ic_timesteps, ...].cpu().numpy(), + ) + + +@pytest.mark.parametrize( + "names, prognostic_names", + [ + pytest.param(["prog"], ["prog"], id="all prognostic"), + pytest.param(["prog", "forcing"], ["prog"], id="some prognostic"), + 
pytest.param(["forcing1", "forcing2"], [], id="no prognostic"), + ], +) +@pytest.mark.parametrize("n_ic_timesteps", [1, 2]) +def test_get_end(names: List[str], prognostic_names: List[str], n_ic_timesteps: int): + n_samples = 2 + n_times = 5 + n_lat = 8 + n_lon = 16 + horizontal_dims = ["lat", "lon"] + batch_data = get_batch_data( + names=names, + n_samples=n_samples, + n_times=n_times, + horizontal_dims=horizontal_dims, + n_lat=n_lat, + n_lon=n_lon, + ) + end = batch_data.get_end(prognostic_names, n_ic_timesteps).as_batch_data() + assert_metadata_equal(end, batch_data) + assert end.time.equals(batch_data.time.isel(time=slice(-n_ic_timesteps, None))) + assert set(end.data.keys()) == set(prognostic_names) + for name in prognostic_names: + assert end.data[name].shape == (n_samples, n_ic_timesteps, n_lat, n_lon) + np.testing.assert_allclose( + end.data[name].cpu().numpy(), + batch_data.data[name][:, -n_ic_timesteps:, ...].cpu().numpy(), + ) + + +@pytest.mark.parametrize( + "names, prepend_names", + [ + pytest.param(["prog"], ["prog"], id="all prepended"), + pytest.param(["prog", "forcing"], ["prog"], id="some prepended"), + ], +) +@pytest.mark.parametrize("n_ic_timesteps", [1, 2]) +def test_prepend(names: List[str], prepend_names: List[str], n_ic_timesteps: int): + n_samples = 2 + n_times = 5 + n_lat = 8 + n_lon = 16 + horizontal_dims = ["lat", "lon"] + batch_data = get_batch_data( + names=names, + n_samples=n_samples, + n_times=n_times, + horizontal_dims=horizontal_dims, + n_lat=n_lat, + n_lon=n_lon, + ) + start_data = batch_data.get_start(prepend_names, n_ic_timesteps) + prepended = batch_data.prepend(start_data) + start_batch_data = start_data.as_batch_data() + assert_metadata_equal(prepended, batch_data) + assert prepended.time.isel(time=slice(n_ic_timesteps, None)).equals(batch_data.time) + assert set(prepended.data.keys()) == set(names) + for name in names: + np.testing.assert_allclose( + prepended.data[name][:, n_ic_timesteps:, ...].cpu().numpy(), + batch_data.data[name].cpu().numpy(), + ) + for name in prepend_names: + np.testing.assert_allclose( + prepended.data[name][:, :n_ic_timesteps, ...].cpu().numpy(), + start_batch_data.data[name].cpu().numpy(), + ) + assert prepended.time.shape == (n_samples, n_times + n_ic_timesteps) + for name in set(names) - set(prepend_names): + assert np.all( + np.isnan(prepended.data[name][:, :n_ic_timesteps, ...].cpu().numpy()) + ) diff --git a/fme/fme/core/data_loading/test_data_loader.py b/fme/fme/ace/data_loading/test_data_loader.py similarity index 59% rename from fme/fme/core/data_loading/test_data_loader.py rename to fme/fme/ace/data_loading/test_data_loader.py index 81ce1ed..3d261df 100644 --- a/fme/fme/core/data_loading/test_data_loader.py +++ b/fme/fme/ace/data_loading/test_data_loader.py @@ -1,8 +1,8 @@ """This file contains unit tests related to creating torch Datasets from climate data (e.g. 
netCDF files).""" -import datetime import math +import os import pathlib from typing import List @@ -12,55 +12,68 @@ import torch import xarray as xr -from fme.core.data_loading.config import DataLoaderConfig, Slice, XarrayDataConfig -from fme.core.data_loading.data_typing import SigmaCoordinates -from fme.core.data_loading.getters import ( +import fme +from fme.ace.data_loading.batch_data import BatchData, PrognosticState +from fme.ace.data_loading.config import DataLoaderConfig +from fme.ace.data_loading.getters import ( get_data_loader, get_forcing_data, get_inference_data, ) -from fme.core.data_loading.inference import ( +from fme.ace.data_loading.inference import ( ExplicitIndices, ForcingDataLoaderConfig, InferenceDataLoaderConfig, + InferenceDataset, InferenceInitialConditionIndices, TimestampList, ) -from fme.core.data_loading.requirements import DataRequirements -from fme.core.data_loading.utils import BatchData, get_times +from fme.ace.data_loading.perturbation import PerturbationSelector, SSTPerturbation +from fme.ace.requirements import PrognosticStateDataRequirements +from fme.core.coordinates import HybridSigmaPressureCoordinate +from fme.core.dataset.config import XarrayDataConfig +from fme.core.dataset.requirements import DataRequirements +from fme.core.typing_ import Slice -def _get_coords(dim_sizes, calendar): +def _get_coords(dim_sizes, calendar, timestep_size=1): coords = {} for dim_name, size in dim_sizes.items(): if dim_name == "time": dtype = np.int64 + step = timestep_size + size = size * step attrs = {"calendar": calendar, "units": "days since 1970-01-01"} else: dtype = np.float32 + step = 1 attrs = {} - coord_value = np.arange(size, dtype=dtype) + coord_value = np.arange(0, size, step, dtype=dtype) coord = xr.DataArray(coord_value, dims=(dim_name,), attrs=attrs) coords[dim_name] = coord return coords -def _save_netcdf(filename, dim_sizes, variable_names, calendar): +def _save_netcdf(filename, dim_sizes, variable_names, calendar, timestep_size=1): data_vars = {} for name in variable_names: - data = np.random.randn(*list(dim_sizes.values())) + if name == "constant_mask": + data = np.ones(list(dim_sizes.values())) + else: + data = np.random.randn(*list(dim_sizes.values())) if len(dim_sizes) > 0: data = data.astype(np.float32) # type: ignore data_vars[name] = xr.DataArray( data, dims=list(dim_sizes), attrs={"units": "m", "long_name": name} ) - coords = _get_coords(dim_sizes, calendar) + coords = _get_coords(dim_sizes, calendar, timestep_size) for i in range(7): data_vars[f"ak_{i}"] = float(i) data_vars[f"bk_{i}"] = float(i + 1) ds = xr.Dataset(data_vars=data_vars, coords=coords) ds.to_netcdf(filename, unlimited_dims=["time"], format="NETCDF4_CLASSIC") + return ds def _create_dataset_on_disk( @@ -68,6 +81,7 @@ def _create_dataset_on_disk( calendar: str = "proleptic_gregorian", data_dim_sizes=None, n_times: int = 3, + timestep_size: int = 1, ) -> pathlib.Path: if data_dim_sizes is None: data_dim_sizes = {"time": n_times, "grid_yt": 16, "grid_xt": 32} @@ -76,10 +90,14 @@ def _create_dataset_on_disk( in_variable_names = ["foo", "bar", "baz"] out_variable_names = ["foo", "bar"] mask_name = "mask" - all_variable_names = list(set(in_variable_names + out_variable_names)) + [mask_name] + constant_mask_name = "constant_mask" + all_variable_names = list(set(in_variable_names + out_variable_names)) + [ + mask_name, + constant_mask_name, + ] data_path = data_dir / "data.nc" - _save_netcdf(data_path, data_dim_sizes, all_variable_names, calendar) + _save_netcdf(data_path, 
data_dim_sizes, all_variable_names, calendar, timestep_size) return data_path @@ -108,8 +126,8 @@ def test_ensemble_loader(tmp_path, num_ensemble_members=3): samples_per_member = n_timesteps - window_timesteps + 1 data = get_data_loader(config, True, requirements) - assert len(data.loader) == samples_per_member * num_ensemble_members - assert isinstance(data.sigma_coordinates, SigmaCoordinates) + assert data.n_batches == samples_per_member * num_ensemble_members + assert isinstance(data.vertical_coordinate, HybridSigmaPressureCoordinate) def test_ensemble_loader_n_samples(tmp_path, num_ensemble_members=3, n_samples=1): @@ -139,12 +157,12 @@ def test_ensemble_loader_n_samples(tmp_path, num_ensemble_members=3, n_samples=1 requirements = DataRequirements(["foo"], window_timesteps) data = get_data_loader(config, True, requirements) - assert len(data.loader) == n_samples * num_ensemble_members - assert isinstance(data.sigma_coordinates, SigmaCoordinates) + assert data.n_batches == n_samples * num_ensemble_members + assert isinstance(data.vertical_coordinate, HybridSigmaPressureCoordinate) def test_xarray_loader(tmp_path): - """Checks that sigma coordinates are present.""" + """Checks that vertical coordinates are present.""" _create_dataset_on_disk(tmp_path) config = DataLoaderConfig( [XarrayDataConfig(data_path=tmp_path, n_repeats=1)], @@ -154,13 +172,14 @@ def test_xarray_loader(tmp_path): window_timesteps = 2 # 1 initial condition and 1 step forward requirements = DataRequirements(["foo"], window_timesteps) data = get_data_loader(config, True, requirements) # type: ignore - assert isinstance(data.sigma_coordinates, SigmaCoordinates) + assert isinstance(data.vertical_coordinate, HybridSigmaPressureCoordinate) + assert data.vertical_coordinate.ak.device == fme.get_device() def test_xarray_loader_hpx(tmp_path): - """Checks that sigma coordinates are present.""" + """Checks that vertical coordinates are present.""" n_times = 3 - data_dim_sizes = {"time": n_times, "face": 12, "width": 64, "height": 64} + data_dim_sizes = {"time": n_times, "face": 12, "width": 16, "height": 16} _create_dataset_on_disk(tmp_path, data_dim_sizes=data_dim_sizes, n_times=n_times) config = DataLoaderConfig( [ @@ -177,9 +196,10 @@ def test_xarray_loader_hpx(tmp_path): for batch in data.loader: assert batch is not None # expect healpix shape - assert batch.data["foo"].shape == (1, window_timesteps, 12, 64, 64) + assert batch.data["foo"].shape == (1, window_timesteps, 12, 16, 16) break - assert isinstance(data.sigma_coordinates, SigmaCoordinates) + assert isinstance(data.vertical_coordinate, HybridSigmaPressureCoordinate) + assert data.vertical_coordinate.ak.device == fme.get_device() def test_loader_n_repeats_but_not_infer_timestep_error(tmp_path): @@ -206,25 +226,42 @@ def test_inference_data_loader(tmp_path): ), ) n_forward_steps_in_memory = 3 - requirements = DataRequirements(["foo"], n_timesteps=7) - data_loader = get_inference_data( + window_requirements = DataRequirements( + names=["foo", "bar"], + n_timesteps=n_forward_steps_in_memory + 1, + ) + initial_condition_requirements = PrognosticStateDataRequirements( + names=["foo"], + n_timesteps=1, + ) + data = get_inference_data( config, - forward_steps_in_memory=n_forward_steps_in_memory, - requirements=requirements, - ).loader + total_forward_steps=6, + window_requirements=window_requirements, + initial_condition=initial_condition_requirements, + ) + data_loader = data.loader batch_data = next(iter(data_loader)) assert isinstance(batch_data, BatchData) - assert 
isinstance(batch_data.data["foo"], torch.Tensor) - assert batch_data.data["foo"].shape[0] == batch_size - assert batch_data.data["foo"].shape[1] == n_forward_steps_in_memory + 1 - assert batch_data.data["foo"].shape[2] == 16 - assert batch_data.data["foo"].shape[3] == 32 - assert isinstance(batch_data.times, xr.DataArray) - assert list(batch_data.times.dims) == ["sample", "time"] - assert batch_data.times.sizes["sample"] == batch_size - assert batch_data.times.sizes["time"] == n_forward_steps_in_memory + 1 - assert batch_data.times.dt.calendar == "proleptic_gregorian" - assert len(data_loader) == 2 + for name in ["foo", "bar"]: + assert isinstance(batch_data.data[name], torch.Tensor) + assert batch_data.data[name].shape == ( + batch_size, + n_forward_steps_in_memory + 1, + 16, + 32, + ) + assert isinstance(batch_data.time, xr.DataArray) + assert list(batch_data.time.dims) == ["sample", "time"] + assert batch_data.time.sizes["sample"] == batch_size + assert batch_data.time.sizes["time"] == n_forward_steps_in_memory + 1 + assert batch_data.time.dt.calendar == "proleptic_gregorian" + assert data._n_batches == 2 + assert data.vertical_coordinate.ak.device == fme.get_device() + initial_condition = data.initial_condition.as_batch_data() + assert isinstance(initial_condition, BatchData) + assert "bar" not in initial_condition.data + assert initial_condition.data["foo"].shape == (batch_size, 1, 16, 32) @pytest.fixture(params=["julian", "proleptic_gregorian", "noleap"]) @@ -255,26 +292,11 @@ def test_data_loader_outputs(tmp_path, calendar): assert isinstance(batch_data, BatchData) assert isinstance(batch_data.data["foo"], torch.Tensor) assert batch_data.data["foo"].shape[0] == n_samples - assert isinstance(batch_data.times, xr.DataArray) - assert list(batch_data.times.dims) == ["sample", "time"] - assert batch_data.times.sizes["sample"] == n_samples - assert batch_data.times.sizes["time"] == window_timesteps - assert batch_data.times.dt.calendar == calendar - - -def test_get_times_non_cftime(): - """ - Check that `get_times` raises an error when the time coordinate is not - cftime.datetime - """ - n_times = 5 - times = [datetime.datetime(2020, 1, 1, i, 0, 0) for i in range(n_times)] - ds = xr.Dataset( - {"foo": xr.DataArray(np.arange(n_times), dims=("time",))}, - coords={"time": times}, - ) - with pytest.raises(AssertionError): - get_times(ds, 0, 1) + assert isinstance(batch_data.time, xr.DataArray) + assert list(batch_data.time.dims) == ["sample", "time"] + assert batch_data.time.sizes["sample"] == n_samples + assert batch_data.time.sizes["time"] == window_timesteps + assert batch_data.time.dt.calendar == calendar @pytest.mark.parametrize( @@ -320,20 +342,29 @@ def test_inference_data_loader_validate_n_forward_steps( start_indices=start_indices, ) n_forward_steps_in_memory = num_forward_steps - requirements = DataRequirements(["foo"], n_timesteps=num_forward_steps + 1) + window_requirements = DataRequirements( + names=["foo", "bar"], + n_timesteps=n_forward_steps_in_memory + 1, + ) + initial_condition_requirements = PrognosticStateDataRequirements( + names=["foo"], + n_timesteps=1, + ) if raises_error: with pytest.raises(ValueError): get_inference_data( config, - forward_steps_in_memory=n_forward_steps_in_memory, - requirements=requirements, + total_forward_steps=num_forward_steps, + window_requirements=window_requirements, + initial_condition=initial_condition_requirements, ) else: get_inference_data( config, - forward_steps_in_memory=n_forward_steps_in_memory, - requirements=requirements, + 
total_forward_steps=num_forward_steps, + window_requirements=window_requirements, + initial_condition=initial_condition_requirements, ) @@ -368,27 +399,42 @@ def test_get_forcing_data(tmp_path, n_initial_conditions): forward_steps_in_memory = 2 _create_dataset_on_disk(tmp_path, calendar=calendar, n_times=10) config = ForcingDataLoaderConfig(XarrayDataConfig(data_path=tmp_path)) - requirements = DataRequirements(["foo"], total_forward_steps + 1) + window_requirements = DataRequirements( + names=["foo"], + n_timesteps=forward_steps_in_memory + 1, + ) time_values = [ - cftime.datetime(1970, 1, 1 + 2 * n, calendar=calendar) + [cftime.datetime(1970, 1, 1 + 2 * n, calendar=calendar)] for n in range(n_initial_conditions) ] - initial_times = xr.DataArray(time_values, dims=["sample"]) - data = get_forcing_data( - config, forward_steps_in_memory, requirements, initial_times + initial_condition = BatchData.new_on_cpu( + data={"foo": torch.randn(n_initial_conditions, 1, 1, 1)}, + time=xr.DataArray(time_values, dims=["sample", "time"]), ) - assert len(data.loader.dataset) == math.ceil( - total_forward_steps / forward_steps_in_memory + data = get_forcing_data( + config, + total_forward_steps, + window_requirements=window_requirements, + initial_condition=PrognosticState(initial_condition), ) + assert data._n_samples == math.ceil(total_forward_steps / forward_steps_in_memory) batch_data = next(iter(data.loader)) assert isinstance(batch_data, BatchData) assert isinstance(batch_data.data["foo"], torch.Tensor) assert set(batch_data.data.keys()) == {"foo"} assert batch_data.data["foo"].shape[0] == len(time_values) assert batch_data.data["foo"].shape[1] == forward_steps_in_memory + 1 - assert list(batch_data.times.dims) == ["sample", "time"] - xr.testing.assert_allclose(batch_data.times.isel(time=0), initial_times) - assert batch_data.times.dt.calendar == calendar + assert list(batch_data.time.dims) == ["sample", "time"] + xr.testing.assert_allclose(batch_data.time[:, 0], initial_condition.time[:, 0]) + assert batch_data.time.dt.calendar == calendar + xr.testing.assert_equal( + data.initial_condition.as_batch_data().time, + initial_condition.time, + ) + np.testing.assert_allclose( + data.initial_condition.as_batch_data().data["foo"].cpu().numpy(), + initial_condition.data["foo"].cpu().numpy(), + ) def test_inference_loader_raises_if_subset(): @@ -446,3 +492,77 @@ def test_TimestampList_as_indices(timestamps, expected_indices): np.testing.assert_equal( timestamp_list.as_indices(time_index), np.array(expected_indices) ) + + +def test_inference_data_with_perturbations(tmp_path): + _create_dataset_on_disk(tmp_path, n_times=14) + batch_size = 1 + step = 7 + config = InferenceDataLoaderConfig( + XarrayDataConfig( + data_path=tmp_path, + n_repeats=1, + ), + start_indices=InferenceInitialConditionIndices( + first=0, n_initial_conditions=batch_size, interval=step + ), + perturbations=SSTPerturbation( + sst=[PerturbationSelector(type="constant", config={"amplitude": 2.0})] + ), + ) + n_forward_steps_in_memory = 3 + original_foo = xr.open_dataset(os.path.join(tmp_path, "data.nc"))["foo"].values[ + 0 : n_forward_steps_in_memory + 1, :, : + ] + window_requirements = DataRequirements( + names=["foo", "constant_mask"], + n_timesteps=n_forward_steps_in_memory + 1, + ) + initial_condition_requirements = PrognosticStateDataRequirements( + names=["foo"], + n_timesteps=1, + ) + data = get_inference_data( + config, + total_forward_steps=6, + window_requirements=window_requirements, + 
initial_condition=initial_condition_requirements, + surface_temperature_name="foo", + ocean_fraction_name="constant_mask", + ) + batch_data = next(iter(data.loader)) + np.testing.assert_allclose( + original_foo + 2.0, + batch_data.data["foo"].cpu().numpy()[0, :, :, :], + ) + np.testing.assert_allclose( + original_foo[:1, :, :] + 2.0, + data.initial_condition.as_batch_data().data["foo"].cpu().numpy()[0, :, :, :], + ) + + +def test_inference_persistence_names(tmp_path): + _create_dataset_on_disk(tmp_path, n_times=14) + + config = InferenceDataLoaderConfig( + XarrayDataConfig(data_path=tmp_path), + start_indices=ExplicitIndices([0, 3]), + persistence_names=["foo"], + ) + window_requirements = DataRequirements( + names=["foo", "bar"], + n_timesteps=3, + ) + dataset = InferenceDataset( + config, + 9, + requirements=window_requirements, + ) + first_item = dataset[0].data + second_item = dataset[1].data + # ensure first and second time steps are the same + torch.testing.assert_close(first_item["foo"][:, 0], first_item["foo"][:, 1]) + # ensure the entire first and second returned items + torch.testing.assert_close(first_item["foo"], second_item["foo"]) + # ensure this is not the case for another variable + assert not torch.all(first_item["bar"] == second_item["bar"]) diff --git a/fme/fme/ace/data_loading/test_data_loading_config.py b/fme/fme/ace/data_loading/test_data_loading_config.py new file mode 100644 index 0000000..cfbd067 --- /dev/null +++ b/fme/fme/ace/data_loading/test_data_loading_config.py @@ -0,0 +1,69 @@ +from datetime import timedelta + +import numpy as np +import pandas as pd +import pytest + +from fme.core.dataset.config import RepeatedInterval + + +def test_repeated_interval_int(): + interval = RepeatedInterval(interval_length=3, block_length=6, start=0) + mask = interval.get_boolean_mask(length=18) + expected_mask = np.array([True, True, True, False, False, False] * 3) + np.testing.assert_array_equal(mask, expected_mask) + + +def test_repeated_interval_str(): + interval = RepeatedInterval(interval_length="1d", block_length="7d", start="2d") + mask = interval.get_boolean_mask(length=21, timestep=timedelta(days=1)) + expected_mask = np.array([False, False, True, False, False, False, False] * 3) + np.testing.assert_array_equal(mask, expected_mask) + + +def test_repeated_interval_mixed_types(): + with pytest.raises(ValueError): + RepeatedInterval(interval_length=3, block_length="6d", start=0) + + +@pytest.mark.parametrize("interval, block, start", [(4, 6, 3), ("2d", "3d", "2d")]) +def test_repeated_interval_invalid_interval_start(interval, block, start): + """start + interval exceeds length of block""" + interval = RepeatedInterval( + interval_length=interval, block_length=block, start=start + ) + with pytest.raises(ValueError): + interval.get_boolean_mask(length=18, timestep=timedelta(days=1)) + + +def test_repeated_interval_zero_length(): + interval = RepeatedInterval(interval_length=0, block_length=6, start=0) + mask = interval.get_boolean_mask(length=18) + expected_mask = np.array([False] * 18) + np.testing.assert_array_equal(mask, expected_mask) + + +def test_repeated_interval_partial_block(): + interval = RepeatedInterval(interval_length=3, block_length=6, start=0) + mask = interval.get_boolean_mask(length=20) + expected_mask = np.array([True, True, True, False, False, False] * 3 + [True, True]) + np.testing.assert_array_equal(mask, expected_mask) + + +def test_repeated_interval_no_timestep_fails_for_timedelta_lengths(): + interval = RepeatedInterval(interval_length="1d", 
block_length="7d", start="0d") + with pytest.raises(ValueError): + interval.get_boolean_mask(length=20) + + +@pytest.mark.parametrize("timestep", ["2h", "150m", "5h", "12h"]) +def test_invalid_timesteps(timestep): + """ + Test that timesteps that don't evenly divide into some or all + arguments raise a ValueError + """ + timestep = pd.to_timedelta(timestep) + with pytest.raises(ValueError): + RepeatedInterval( + interval_length="5h", start="4h", block_length="10h" + ).get_boolean_mask(length=18, timestep=timestep) diff --git a/fme/fme/core/data_loading/test_metadata.py b/fme/fme/ace/data_loading/test_metadata.py similarity index 57% rename from fme/fme/core/data_loading/test_metadata.py rename to fme/fme/ace/data_loading/test_metadata.py index 702e7a3..19f4e94 100644 --- a/fme/fme/core/data_loading/test_metadata.py +++ b/fme/fme/ace/data_loading/test_metadata.py @@ -6,32 +6,11 @@ import pytest import xarray as xr -from fme.core.data_loading.config import DataLoaderConfig, XarrayDataConfig -from fme.core.data_loading.data_typing import VariableMetadata -from fme.core.data_loading.getters import get_data_loader -from fme.core.data_loading.requirements import DataRequirements - -METADATA = [ - pytest.param( - {"bar": None}, - id="one_var_no_metadata", - ), - pytest.param( - {"bar": VariableMetadata("km", "bar_long_name")}, - id="one_var_metadata", - ), - pytest.param( - {"foo": VariableMetadata("m", "foo_long_name"), "bar": None}, - id="two_vars_one_metadata", - ), - pytest.param( - { - "foo": VariableMetadata("m", "foo_long_name"), - "bar": VariableMetadata("km", "bar_long_name"), - }, - id="two_vars_two_metadata", - ), -] +from fme.ace.data_loading.config import DataLoaderConfig +from fme.ace.data_loading.getters import get_data_loader +from fme.core.dataset.config import XarrayDataConfig +from fme.core.dataset.data_typing import VariableMetadata +from fme.core.dataset.requirements import DataRequirements def _coord_value(name, size): @@ -47,18 +26,18 @@ def _coord_value(name, size): def _save_netcdf( filename, - metadata: Mapping[str, Optional[VariableMetadata]], + variable_metadata: Mapping[str, Optional[VariableMetadata]], num_members=1, dim_sizes=None, ): if dim_sizes is None: dim_sizes = {"time": 3, "grid_yt": 16, "grid_xt": 32} data_vars = {} - for name in metadata: + for name in variable_metadata: data = np.random.randn(*list(dim_sizes.values())) if len(dim_sizes) > 0: data = data.astype(np.float32) - item_metadata = metadata[name] + item_metadata = variable_metadata[name] if item_metadata is None: attrs = {} else: @@ -84,24 +63,49 @@ def _save_netcdf( @pytest.mark.parametrize("n_ensemble_members", [1, 2]) -@pytest.mark.parametrize("metadata", METADATA) -def test_metadata(tmp_path, metadata, n_ensemble_members): +@pytest.mark.parametrize( + "variable_metadata", + [ + pytest.param( + {"bar": None}, + id="one_var_no_metadata", + ), + pytest.param( + {"bar": VariableMetadata("km", "bar_long_name")}, + id="one_var_metadata", + ), + pytest.param( + {"foo": VariableMetadata("m", "foo_long_name"), "bar": None}, + id="two_vars_one_metadata", + ), + pytest.param( + { + "foo": VariableMetadata("m", "foo_long_name"), + "bar": VariableMetadata("km", "bar_long_name"), + }, + id="two_vars_two_metadata", + ), + ], +) +def test_metadata(tmp_path, variable_metadata, n_ensemble_members): paths = [] for i in range(n_ensemble_members): path = tmp_path / f"ic{i}" path.mkdir(exist_ok=True) paths.append(path) - _save_netcdf(path / "data.nc", metadata) + _save_netcdf(path / "data.nc", variable_metadata) 
config = DataLoaderConfig( [XarrayDataConfig(data_path=str(path)) for path in paths], batch_size=1, num_data_workers=0, ) - var_names = list(metadata.keys()) + var_names = list(variable_metadata.keys()) requirements = DataRequirements(names=var_names, n_timesteps=2) data = get_data_loader(config=config, train=True, requirements=requirements) target_metadata = { - name: metadata[name] for name in metadata if metadata[name] is not None + name: variable_metadata[name] + for name in variable_metadata + if variable_metadata[name] is not None } - assert data.metadata == target_metadata # type: ignore + assert data.variable_metadata == target_metadata # type: ignore diff --git a/fme/fme/ace/data_loading/test_perturbation.py b/fme/fme/ace/data_loading/test_perturbation.py new file mode 100644 index 0000000..296d14b --- /dev/null +++ b/fme/fme/ace/data_loading/test_perturbation.py @@ -0,0 +1,54 @@ +import torch + +import fme +from fme.ace.data_loading.perturbation import ( + ConstantConfig, + GreensFunctionConfig, + PerturbationSelector, +) + + +def test_constant_perturbation_config(): + selector = PerturbationSelector( + type="constant", + config={"amplitude": 1.0}, + ) + perturbation = selector.build() + assert isinstance(perturbation, ConstantConfig) + assert perturbation.amplitude == 1.0 + nx, ny = 5, 5 + lat = torch.arange(nx, device=fme.get_device()) + lon = torch.arange(ny, device=fme.get_device()) + lats, lons = torch.meshgrid(lat, lon, indexing="ij") + ocean_fraction = torch.ones(nx, ny, device=fme.get_device()) + data = torch.ones(nx, ny, device=fme.get_device()) + expected = 2.0 * torch.ones(nx, ny, device=fme.get_device()) + perturbation.apply_perturbation(data, lats, lons, ocean_fraction) + torch.testing.assert_close(data, expected) + + +def test_green_function_perturbation_config(): + selector = PerturbationSelector( + type="greens_function", + config={ + "amplitude": 1.0, + "lat_center": 0.0, + "lon_center": 0.0, + "lat_width": 10.0, + "lon_width": 10.0, + }, + ) + perturbation = selector.build() + assert isinstance(perturbation, GreensFunctionConfig) + assert perturbation.amplitude == 1.0 + assert perturbation.lat_center == 0.0 + assert perturbation.lon_center == 0.0 + assert perturbation.lat_width == 10.0 + assert perturbation.lon_width == 10.0 + nx, ny = 5, 5 + lat = torch.arange(nx, device=fme.get_device()) + lon = torch.arange(ny, device=fme.get_device()) + lats, lons = torch.meshgrid(lat, lon, indexing="ij") + ocean_fraction = torch.ones(nx, ny, device=fme.get_device()) + data = torch.ones(nx, ny, device=fme.get_device()) + perturbation.apply_perturbation(data, lats, lons, ocean_fraction) diff --git a/fme/fme/ace/inference/__init__.py b/fme/fme/ace/inference/__init__.py index 20dba26..e69de29 100644 --- a/fme/fme/ace/inference/__init__.py +++ b/fme/fme/ace/inference/__init__.py @@ -1 +0,0 @@ -from .loop import run_inference_evaluator diff --git a/fme/fme/ace/inference/__main__.py b/fme/fme/ace/inference/__main__.py index 8825044..e9ccddd 100644 --- a/fme/fme/ace/inference/__main__.py +++ b/fme/fme/ace/inference/__main__.py @@ -4,5 +4,13 @@ parser = argparse.ArgumentParser() parser.add_argument("yaml_config", type=str) +parser.add_argument( + "--segments", + type=int, + default=None, + help="If provided, number of times to repeat the inference in time, saving each " + "segment in a separate folder labeled as 'segment_0000', 'segment_0001' etc. 
" + "WARNING: this feature is experimental and its API is subject to change.", +) args = parser.parse_args() -main(yaml_config=args.yaml_config) +main(yaml_config=args.yaml_config, segments=args.segments) diff --git a/fme/fme/ace/inference/data_writer/__init__.py b/fme/fme/ace/inference/data_writer/__init__.py index 1749072..641c6f8 100644 --- a/fme/fme/ace/inference/data_writer/__init__.py +++ b/fme/fme/ace/inference/data_writer/__init__.py @@ -1 +1 @@ -from .main import DataWriter, DataWriterConfig, NullDataWriter, PairedDataWriter +from .main import DataWriter, DataWriterConfig, PairedDataWriter diff --git a/fme/fme/ace/inference/data_writer/histograms.py b/fme/fme/ace/inference/data_writer/histograms.py index 7d2cc5b..c87057c 100644 --- a/fme/fme/ace/inference/data_writer/histograms.py +++ b/fme/fme/ace/inference/data_writer/histograms.py @@ -5,7 +5,7 @@ import torch import xarray as xr -from fme.core.data_loading.data_typing import VariableMetadata +from fme.core.dataset.data_typing import VariableMetadata from fme.core.histogram import DynamicHistogram @@ -73,21 +73,21 @@ def __init__( self, path: str, n_timesteps: int, - metadata: Mapping[str, VariableMetadata], + variable_metadata: Mapping[str, VariableMetadata], save_names: Optional[Sequence[str]], ): self._target_writer = HistogramDataWriter( path=path, n_timesteps=n_timesteps, filename="histograms_target.nc", - metadata=metadata, + variable_metadata=variable_metadata, save_names=save_names, ) self._prediction_writer = HistogramDataWriter( path=path, n_timesteps=n_timesteps, filename="histograms_prediction.nc", - metadata=metadata, + variable_metadata=variable_metadata, save_names=save_names, ) @@ -96,17 +96,17 @@ def append_batch( target: Dict[str, torch.Tensor], prediction: Dict[str, torch.Tensor], start_timestep: int, - batch_times: xr.DataArray, + batch_time: xr.DataArray, ): self._target_writer.append_batch( data=target, start_timestep=start_timestep, - batch_times=batch_times, + batch_time=batch_time, ) self._prediction_writer.append_batch( data=prediction, start_timestep=start_timestep, - batch_times=batch_times, + batch_time=batch_time, ) def flush(self): @@ -124,36 +124,37 @@ def __init__( path: str, n_timesteps: int, filename: str, - metadata: Mapping[str, VariableMetadata], + variable_metadata: Mapping[str, VariableMetadata], save_names: Optional[Sequence[str]], ): """ Args: - path: Path to write netCDF file(s). + path: The directory within which to write the file. n_timesteps: Number of timesteps to write to the file. - metadata: Metadata for each variable to be written to the file. + filename: Name of the file to write. + variable_metadata: Metadata for each variable to be written to the file. + save_names: Names of variables to save. If None, all variables are saved. """ self.path = path self._metrics_filename = str(Path(path) / filename) - self.metadata = metadata + self.variable_metadata = variable_metadata self._histogram = _HistogramAggregator(n_times=n_timesteps, names=save_names) def append_batch( self, data: Dict[str, torch.Tensor], start_timestep: int, - batch_times: xr.DataArray, + batch_time: xr.DataArray, ): """ Append a batch of data to the file. Args: - target: Target data. - prediction: Prediction data. + data: The data to write. start_timestep: Timestep at which to start writing. - batch_times: Time coordinates for each sample in the batch. + batch_time: Time coordinate for each sample in the batch. 
""" - del batch_times + del batch_time self._histogram.record_batch( data=data, i_time_start=start_timestep, @@ -164,11 +165,11 @@ def flush(self): Flush the data to disk. """ metric_dataset = self._histogram.get_dataset() - for name in self.metadata: + for name in self.variable_metadata: try: - metric_dataset[f"{name}_bin_edges"].attrs["units"] = self.metadata[ - name - ].units + metric_dataset[f"{name}_bin_edges"].attrs["units"] = ( + self.variable_metadata[name].units + ) except KeyError: logging.info( f"{name} in metadata but not in data written to " @@ -176,9 +177,9 @@ def flush(self): ) for name in metric_dataset.data_vars: if not name.endswith("_bin_edges"): - metric_dataset[f"{name}_bin_edges"].attrs[ - "long_name" - ] = f"{name} bin edges" + metric_dataset[f"{name}_bin_edges"].attrs["long_name"] = ( + f"{name} bin edges" + ) metric_dataset[name].attrs["units"] = "count" metric_dataset[name].attrs["long_name"] = f"{name} histogram" metric_dataset.to_netcdf(self._metrics_filename) diff --git a/fme/fme/ace/inference/data_writer/main.py b/fme/fme/ace/inference/data_writer/main.py index b70aa1c..df908dd 100644 --- a/fme/fme/ace/inference/data_writer/main.py +++ b/fme/fme/ace/inference/data_writer/main.py @@ -2,18 +2,19 @@ import datetime import warnings from pathlib import Path -from typing import Dict, List, Mapping, Optional, Sequence, Union +from typing import List, Mapping, Optional, Sequence, Union import numpy as np import torch import xarray as xr -from fme.core.data_loading.data_typing import VariableMetadata +from fme.ace.data_loading.batch_data import BatchData, PairedData, PrognosticState +from fme.core.dataset.data_typing import VariableMetadata +from fme.core.generics.writer import WriterABC from .histograms import PairedHistogramDataWriter from .monthly import MonthlyDataWriter, PairedMonthlyDataWriter, months_for_timesteps from .raw import PairedRawDataWriter, RawDataWriter -from .restart import PairedRestartWriter, RestartWriter from .time_coarsen import PairedTimeCoarsen, TimeCoarsen, TimeCoarsenConfig from .video import PairedVideoDataWriter @@ -23,10 +24,9 @@ PairedHistogramDataWriter, PairedTimeCoarsen, PairedMonthlyDataWriter, - PairedRestartWriter, ] -Subwriter = Union[MonthlyDataWriter, RawDataWriter, RestartWriter, TimeCoarsen] +Subwriter = Union[MonthlyDataWriter, RawDataWriter, TimeCoarsen] @dataclasses.dataclass @@ -34,7 +34,7 @@ class DataWriterConfig: """ Configuration for inference data writers. - Attributes: + Parameters: log_extended_video_netcdfs: Whether to enable writing of netCDF files containing video metrics. 
save_prediction_files: Whether to enable writing of netCDF files @@ -73,25 +73,23 @@ def __post_init__(self): def build_paired( self, experiment_dir: str, - n_samples: int, + n_initial_conditions: int, n_timesteps: int, timestep: datetime.timedelta, - prognostic_names: Sequence[str], - metadata: Mapping[str, VariableMetadata], + variable_metadata: Mapping[str, VariableMetadata], coords: Mapping[str, np.ndarray], ) -> "PairedDataWriter": return PairedDataWriter( path=experiment_dir, - n_samples=n_samples, + n_initial_conditions=n_initial_conditions, n_timesteps=n_timesteps, timestep=timestep, - metadata=metadata, + variable_metadata=variable_metadata, coords=coords, enable_prediction_netcdfs=self.save_prediction_files, enable_monthly_netcdfs=self.save_monthly_files, enable_video_netcdfs=self.log_extended_video_netcdfs, save_names=self.names, - prognostic_names=prognostic_names, enable_histogram_netcdfs=self.save_histogram_files, time_coarsen=self.time_coarsen, ) @@ -99,11 +97,10 @@ def build_paired( def build( self, experiment_dir: str, - n_samples: int, + n_initial_conditions: int, n_timesteps: int, timestep: datetime.timedelta, - prognostic_names: Sequence[str], - metadata: Mapping[str, VariableMetadata], + variable_metadata: Mapping[str, VariableMetadata], coords: Mapping[str, np.ndarray], ) -> "DataWriter": if self.save_histogram_files: @@ -118,43 +115,42 @@ def build( ) return DataWriter( path=experiment_dir, - n_samples=n_samples, + n_initial_conditions=n_initial_conditions, n_timesteps=n_timesteps, - metadata=metadata, + variable_metadata=variable_metadata, coords=coords, timestep=timestep, enable_prediction_netcdfs=self.save_prediction_files, enable_monthly_netcdfs=self.save_monthly_files, save_names=self.names, - prognostic_names=prognostic_names, time_coarsen=self.time_coarsen, ) -class PairedDataWriter: +class PairedDataWriter(WriterABC[PrognosticState, PairedData]): def __init__( self, path: str, - n_samples: int, + n_initial_conditions: int, n_timesteps: int, - metadata: Mapping[str, VariableMetadata], + variable_metadata: Mapping[str, VariableMetadata], coords: Mapping[str, np.ndarray], timestep: datetime.timedelta, enable_prediction_netcdfs: bool, enable_monthly_netcdfs: bool, enable_video_netcdfs: bool, save_names: Optional[Sequence[str]], - prognostic_names: Sequence[str], enable_histogram_netcdfs: bool, time_coarsen: Optional[TimeCoarsenConfig] = None, ): """ Args: path: Path to write netCDF file(s). - n_samples: Number of samples to write to the file. + n_initial_conditions: Number of ICs/ensemble members to write to the file. n_timesteps: Number of timesteps to write to the file. - metadata: Metadata for each variable to be written to the file. + variable_metadata: Metadata for each variable to be written to the file. coords: Coordinate data to be written to the file. + timestep: Timestep of the model. enable_prediction_netcdfs: Whether to enable writing of netCDF files containing the predictions and target values. 
enable_monthly_netcdfs: Whether to enable writing of netCDF files @@ -169,8 +165,7 @@ def __init__( self._writers: List[PairedSubwriter] = [] self.path = path self.coords = coords - self.metadata = metadata - self.prognostic_names = prognostic_names + self.variable_metadata = variable_metadata if time_coarsen is not None: n_coarsened_timesteps = time_coarsen.n_coarsened_timesteps(n_timesteps) @@ -187,9 +182,9 @@ def _time_coarsen_builder(data_writer: PairedSubwriter) -> PairedSubwriter: _time_coarsen_builder( PairedRawDataWriter( path=path, - n_samples=n_samples, + n_initial_conditions=n_initial_conditions, save_names=save_names, - metadata=metadata, + variable_metadata=variable_metadata, coords=coords, ) ) @@ -198,11 +193,11 @@ def _time_coarsen_builder(data_writer: PairedSubwriter) -> PairedSubwriter: self._writers.append( PairedMonthlyDataWriter( path=path, - n_samples=n_samples, + n_samples=n_initial_conditions, n_timesteps=n_timesteps, timestep=timestep, save_names=save_names, - metadata=metadata, + variable_metadata=variable_metadata, coords=coords, ) ) @@ -212,7 +207,7 @@ def _time_coarsen_builder(data_writer: PairedSubwriter) -> PairedSubwriter: PairedVideoDataWriter( path=path, n_timesteps=n_coarsened_timesteps, - metadata=metadata, + variable_metadata=variable_metadata, coords=coords, ) ) @@ -223,66 +218,46 @@ def _time_coarsen_builder(data_writer: PairedSubwriter) -> PairedSubwriter: PairedHistogramDataWriter( path=path, n_timesteps=n_coarsened_timesteps, - metadata=metadata, + variable_metadata=variable_metadata, save_names=save_names, ) ) ) - self._writers.append( - PairedRestartWriter( - path=path, - is_restart_step=lambda i: i == n_timesteps - 1, - prognostic_names=prognostic_names, - metadata=metadata, - coords=coords, - ) - ) + self._n_timesteps_seen = 0 - def save_initial_condition( - self, - ic_data: Dict[str, torch.Tensor], - ic_time: xr.DataArray, - ): - data_arrays = {} - for name in self.prognostic_names: - if name not in ic_data: - raise KeyError( - f"Initial condition data missing for prognostic variable {name}." - ) - data = ic_data[name].cpu().numpy() - data_arrays[name] = xr.DataArray(data, dims=["sample", "lat", "lon"]) - if name in self.metadata: - data_arrays[name].attrs = { - "long_name": self.metadata[name].long_name, - "units": self.metadata[name].units, - } - data_arrays["time"] = ic_time - ds = xr.Dataset(data_arrays, coords=self.coords) - ds.to_netcdf(str(Path(self.path) / "initial_condition.nc")) + def write(self, data: PrognosticState, filename: str): + """Eagerly write data to a single netCDF file. + + Args: + data: the data to be written. + filename: the filename to use for the netCDF file. + """ + _write( + data=data.as_batch_data(), + path=self.path, + filename=filename, + variable_metadata=self.variable_metadata, + coords=self.coords, + ) def append_batch( self, - target: Dict[str, torch.Tensor], - prediction: Dict[str, torch.Tensor], - start_timestep: int, - batch_times: xr.DataArray, + batch: PairedData, ): """ Append a batch of data to the file. Args: - target: Target data. - prediction: Prediction data. - start_timestep: Timestep at which to start writing. - batch_times: Time coordinates for each sample in the batch. + batch: Prediction and target data. 
""" for writer in self._writers: writer.append_batch( - target=target, - prediction=prediction, - start_timestep=start_timestep, - batch_times=batch_times, + target=dict(batch.target), + prediction=dict(batch.prediction), + start_timestep=self._n_timesteps_seen, + batch_time=batch.time, ) + self._n_timesteps_seen += batch.time.shape[1] def flush(self): """ @@ -292,27 +267,76 @@ def flush(self): writer.flush() -class DataWriter: +def _write( + data: BatchData, + path: str, + filename: str, + variable_metadata: Mapping[str, VariableMetadata], + coords: Mapping[str, np.ndarray], +): + """Write provided data to a single netCDF at specified path/filename. + + If the data has only one timestep, the data is squeezed to remove + the time dimension. + + Args: + data: Batch data to written. + path: Directory to write the netCDF file in. + filename: filename to use for netCDF. + variable_metadata: Metadata for each variable to be written to the file. + coords: Coordinate data to be written to the file. + """ + if data.time.sizes["time"] == 1: + time_dim = data.dims.index("time") + dims_to_write = data.dims[:time_dim] + data.dims[time_dim + 1 :] + + def maybe_squeeze(x: torch.Tensor) -> torch.Tensor: + return x.squeeze(dim=time_dim) + + time_array = data.time.isel(time=0) + else: + dims_to_write = data.dims + + def maybe_squeeze(x): + return x + + time_array = data.time + + data_arrays = {} + for name in data.data: + array = maybe_squeeze(data.data[name]).cpu().numpy() + data_arrays[name] = xr.DataArray(array, dims=dims_to_write) + if name in variable_metadata: + data_arrays[name].attrs = { + "long_name": variable_metadata[name].long_name, + "units": variable_metadata[name].units, + } + data_arrays["time"] = time_array + ds = xr.Dataset(data_arrays, coords=coords) + ds.to_netcdf(str(Path(path) / filename)) + + +class DataWriter(WriterABC[PrognosticState, BatchData]): def __init__( self, path: str, - n_samples: int, + n_initial_conditions: int, n_timesteps: int, - metadata: Mapping[str, VariableMetadata], + variable_metadata: Mapping[str, VariableMetadata], coords: Mapping[str, np.ndarray], timestep: datetime.timedelta, enable_prediction_netcdfs: bool, enable_monthly_netcdfs: bool, save_names: Optional[Sequence[str]], - prognostic_names: Sequence[str], time_coarsen: Optional[TimeCoarsenConfig] = None, ): """ Args: path: Directory within which to write netCDF file(s). - n_samples: Number of samples to write to the file. + n_initial_conditions: Number of initial conditions / timeseries + to write to the file. n_timesteps: Number of timesteps to write to the file. - metadata: Metadata for each variable to be written to the file. + variable_metadata: Metadata for each variable to be written to the file. coords: Coordinate data to be written to the file. timestep: Timestep of the model. enable_prediction_netcdfs: Whether to enable writing of netCDF files @@ -323,6 +347,10 @@ def __init__( time_coarsen: Configuration for time coarsening of raw outputs. 
""" self._writers: List[Subwriter] = [] + if "face" in coords: + # TODO: handle writing HEALPix data + # https://github.com/ai2cm/full-model/issues/1089 + return def _time_coarsen_builder(data_writer: Subwriter) -> Subwriter: if time_coarsen is not None: @@ -335,9 +363,9 @@ def _time_coarsen_builder(data_writer: Subwriter) -> Subwriter: RawDataWriter( path=path, label="autoregressive_predictions.nc", - n_samples=n_samples, + n_initial_conditions=n_initial_conditions, save_names=save_names, - metadata=metadata, + variable_metadata=variable_metadata, coords=coords, ) ) @@ -348,40 +376,33 @@ def _time_coarsen_builder(data_writer: Subwriter) -> Subwriter: MonthlyDataWriter( path=path, label="predictions", - n_samples=n_samples, + n_samples=n_initial_conditions, n_months=months_for_timesteps(n_timesteps, timestep), save_names=save_names, - metadata=metadata, + variable_metadata=variable_metadata, coords=coords, ) ) - self._writers.append( - RestartWriter( - path=path, - is_restart_step=lambda i: i == n_timesteps - 1, - prognostic_names=prognostic_names, - metadata=metadata, - coords=coords, - ) - ) + self.path = path + self.variable_metadata = variable_metadata + self.coords = coords + self._n_timesteps_seen = 0 - def append_batch( - self, - data: Dict[str, torch.Tensor], - start_timestep: int, - batch_times: xr.DataArray, - ): + def append_batch(self, batch: BatchData): """ - Append a batch of data to the file. + Append prediction data to the file. + Args: - data: Data to write. - start_timestep: Timestep at which to start writing. - start_sample: Sample at which to start writing. - batch_times: Time coordinates for each sample in the batch. + batch: Data to be written. """ for writer in self._writers: - writer.append_batch(data, start_timestep, batch_times) + writer.append_batch( + data=dict(batch.data), + start_timestep=self._n_timesteps_seen, + batch_time=batch.time, + ) + self._n_timesteps_seen += batch.time.shape[1] def flush(self): """ @@ -390,30 +411,11 @@ def flush(self): for writer in self._writers: writer.flush() - -class NullDataWriter: - """ - Null pattern for DataWriter, which does nothing. 
- """ - - def __init__(self): - pass - - def append_batch( - self, - target: Dict[str, torch.Tensor], - prediction: Dict[str, torch.Tensor], - start_timestep: int, - batch_times: xr.DataArray, - ): - pass - - def flush(self): - pass - - def save_initial_condition( - self, - ic_data: Dict[str, torch.Tensor], - ic_time: xr.DataArray, - ): - pass + def write(self, data: PrognosticState, filename: str): + _write( + data=data.as_batch_data(), + path=self.path, + filename=filename, + variable_metadata=self.variable_metadata, + coords=self.coords, + ) diff --git a/fme/fme/ace/inference/data_writer/monthly.py b/fme/fme/ace/inference/data_writer/monthly.py index 6254f38..3b8f609 100644 --- a/fme/fme/ace/inference/data_writer/monthly.py +++ b/fme/fme/ace/inference/data_writer/monthly.py @@ -9,8 +9,12 @@ import xarray as xr from netCDF4 import Dataset -from fme.ace.inference.data_writer.utils import get_all_names -from fme.core.data_loading.data_typing import VariableMetadata +from fme.ace.inference.data_writer.utils import ( + DIM_INFO_HEALPIX, + DIM_INFO_LATLON, + get_all_names, +) +from fme.core.dataset.data_typing import VariableMetadata LEAD_TIME_DIM = "time" LEAD_TIME_UNITS = "months" @@ -41,7 +45,7 @@ def __init__( n_timesteps: int, timestep: datetime.timedelta, save_names: Optional[Sequence[str]], - metadata: Mapping[str, VariableMetadata], + variable_metadata: Mapping[str, VariableMetadata], coords: Mapping[str, np.ndarray], ): n_months = months_for_timesteps(n_timesteps, timestep) @@ -51,7 +55,7 @@ def __init__( n_samples=n_samples, n_months=n_months, save_names=save_names, - metadata=metadata, + variable_metadata=variable_metadata, coords=coords, ) self._prediction_writer = MonthlyDataWriter( @@ -60,7 +64,7 @@ def __init__( n_samples=n_samples, n_months=n_months, save_names=save_names, - metadata=metadata, + variable_metadata=variable_metadata, coords=coords, ) @@ -69,13 +73,13 @@ def append_batch( target: Dict[str, torch.Tensor], prediction: Dict[str, torch.Tensor], start_timestep: int, - batch_times: xr.DataArray, + batch_time: xr.DataArray, ): self._target_writer.append_batch( - data=target, start_timestep=start_timestep, batch_times=batch_times + data=target, start_timestep=start_timestep, batch_time=batch_time ) self._prediction_writer.append_batch( - data=prediction, start_timestep=start_timestep, batch_times=batch_times + data=prediction, start_timestep=start_timestep, batch_time=batch_time ) def flush(self): @@ -97,7 +101,7 @@ def __init__( n_samples: int, n_months: int, save_names: Optional[Sequence[str]], - metadata: Mapping[str, VariableMetadata], + variable_metadata: Mapping[str, VariableMetadata], coords: Mapping[str, np.ndarray], ): """ @@ -109,14 +113,14 @@ def __init__( n_months: Number of months to write to the file. save_names: Names of variables to save in the predictions netcdf file. If None, all predicted variables will be saved. - metadata: Metadata for each variable to be written to the file. + variable_metadata: Metadata for each variable to be written to the file. coords: Coordinate data to be written to the file. 
""" if label != "": label = "_" + label filename = str(Path(path) / f"monthly_mean{label}.nc") self._save_names = save_names - self.metadata = metadata + self.variable_metadata = variable_metadata self.coords = coords self.dataset = Dataset(filename, "w", format="NETCDF4") self.dataset.createDimension(LEAD_TIME_DIM, n_months) @@ -137,10 +141,10 @@ def __init__( ) self.dataset.variables[VALID_TIME].units = TIME_UNITS self.dataset.variables[COUNTS][:] = 0 - self._n_lat: Optional[int] = None - self._n_lon: Optional[int] = None + self._init_years = np.full([n_samples], -1, dtype=int) self._init_months = np.full([n_samples], -1, dtype=int) + self._dataset_dims_created = False def _get_initial_year_and_month( self, @@ -165,7 +169,7 @@ def _get_initial_year_and_month( self.dataset.variables[VALID_TIME][:, :] = days_since_reference + 14 return (self._init_years, self._init_months) - def _get_month_indices(self, batch_times: xr.DataArray) -> np.ndarray: + def _get_month_indices(self, batch_time: xr.DataArray) -> np.ndarray: """ Get the month indices for the batch of data. @@ -174,16 +178,16 @@ def _get_month_indices(self, batch_times: xr.DataArray) -> np.ndarray: indices in this and future calls. Args: - batch_times: Time coordinates for each sample in the batch, of shape + batch_time: Time coordinate for each sample in the batch, of shape [ensemble_member, lead_time]. Returns: Month indices for the batch of data. """ - calendar = batch_times.dt.calendar - years = batch_times.dt.year.values + calendar = batch_time.dt.calendar + years = batch_time.dt.year.values # datetime months are 1-indexed, we want 0-indexed - months = batch_times.dt.month.values - 1 + months = batch_time.dt.month.values - 1 init_years, init_months = self._get_initial_year_and_month( years=years[:, 0], months=months[:, 0], calendar=calendar ) @@ -198,7 +202,7 @@ def append_batch( self, data: Dict[str, torch.Tensor], start_timestep: int, - batch_times: xr.DataArray, + batch_time: xr.DataArray, ): """ Append a batch of data to the file. @@ -206,40 +210,39 @@ def append_batch( Args: data: Values to store. start_timestep: Timestep index for the start of the batch, unused. - batch_times: Time coordinates for each sample in the batch. + batch_time: Time coordinate for each sample in the batch. """ del start_timestep # unused n_samples_data = list(data.values())[0].shape[0] - n_samples_time = batch_times.sizes["sample"] + n_samples_time = batch_time.sizes["sample"] if n_samples_data != n_samples_time: raise ValueError( f"Batch size mismatch, data has {n_samples_data} samples " - f"and times has {n_samples_time} samples." + f"and batch_time has {n_samples_time} samples." ) n_times_data = list(data.values())[0].shape[1] - n_times_time = batch_times.sizes["time"] + n_times_time = batch_time.sizes["time"] if n_times_data != n_times_time: raise ValueError( f"Batch time dimension mismatch, data has {n_times_data} times " - f"and times has {n_times_time} times." + f"and batch_time has {n_times_time} times." 
) - if self._n_lat is None: - self._n_lat = data[next(iter(data.keys()))].shape[-2] - self.dataset.createDimension("lat", self._n_lat) - if "lat" in self.coords: - self.dataset.createVariable("lat", "f4", ("lat",)) - self.dataset.variables["lat"][:] = self.coords["lat"] - if self._n_lon is None: - self._n_lon = data[next(iter(data.keys()))].shape[-1] - self.dataset.createDimension("lon", self._n_lon) - if "lon" in self.coords: - self.dataset.createVariable("lon", "f4", ("lon",)) - self.dataset.variables["lon"][:] = self.coords["lon"] - - dims = (ENSEMBLE_DIM, LEAD_TIME_DIM, "lat", "lon") + if not self._dataset_dims_created: + _dim_info = DIM_INFO_HEALPIX if "face" in self.coords else DIM_INFO_LATLON + _ordered_names = [] + for dim in _dim_info: + dim_size = data[next(iter(data.keys()))].shape[dim.index] + self.dataset.createDimension(dim.name, dim_size) + if dim.name in self.coords: + self.dataset.createVariable(dim.name, "f4", (dim.name,)) + self.dataset.variables[dim.name][:] = self.coords[dim.name] + _ordered_names.append(dim.name) + dims = (ENSEMBLE_DIM, LEAD_TIME_DIM, *_ordered_names) + self._dataset_dims_created = True + save_names = self._get_variable_names_to_save(data.keys()) - months = self._get_month_indices(batch_times) + months = self._get_month_indices(batch_time) month_min = np.min(months) month_range = np.max(months) - month_min + 1 count_data = self.dataset.variables[COUNTS][ @@ -255,13 +258,13 @@ def append_batch( fill_value=np.nan, ) self.dataset.variables[variable_name][:] = 0.0 - if variable_name in self.metadata: - self.dataset.variables[variable_name].units = self.metadata[ + if variable_name in self.variable_metadata: + self.dataset.variables[ variable_name - ].units - self.dataset.variables[variable_name].long_name = self.metadata[ + ].units = self.variable_metadata[variable_name].units + self.dataset.variables[ variable_name - ].long_name + ].long_name = self.variable_metadata[variable_name].long_name self.dataset.variables[variable_name].coordinates = " ".join( [INIT_TIME, VALID_TIME] ) @@ -385,5 +388,7 @@ def get_days_since_reference( freq="MS", calendar=calendar, ) - days_since_reference[i, :] = (dates_sample - reference_date).days + days_since_reference[i, :] = ( + dates_sample.values - reference_date + ) // datetime.timedelta(days=1) return days_since_reference diff --git a/fme/fme/ace/inference/data_writer/raw.py b/fme/fme/ace/inference/data_writer/raw.py index b4f37b8..5026957 100644 --- a/fme/fme/ace/inference/data_writer/raw.py +++ b/fme/fme/ace/inference/data_writer/raw.py @@ -9,12 +9,16 @@ import xarray as xr from netCDF4 import Dataset -from fme.ace.inference.data_writer.utils import get_all_names -from fme.core.data_loading.data_typing import VariableMetadata +from fme.ace.inference.data_writer.utils import ( + DIM_INFO_HEALPIX, + DIM_INFO_LATLON, + get_all_names, +) +from fme.core.dataset.data_typing import VariableMetadata LEAD_TIME_DIM = "time" LEAD_TIME_UNITS = "microseconds" -SAMPLE_DIM = "sample" +IC_DIM = "sample" INIT_TIME = "init_time" INIT_TIME_UNITS = "microseconds since 1970-01-01 00:00:00" VALID_TIME = "valid_time" @@ -30,25 +34,25 @@ class PairedRawDataWriter: def __init__( self, path: str, - n_samples: int, + n_initial_conditions: int, save_names: Optional[Sequence[str]], - metadata: Mapping[str, VariableMetadata], + variable_metadata: Mapping[str, VariableMetadata], coords: Mapping[str, np.ndarray], ): self._target_writer = RawDataWriter( path=path, label="autoregressive_target.nc", - n_samples=n_samples, + 
n_initial_conditions=n_initial_conditions, save_names=save_names, - metadata=metadata, + variable_metadata=variable_metadata, coords=coords, ) self._prediction_writer = RawDataWriter( path=path, label="autoregressive_predictions.nc", - n_samples=n_samples, + n_initial_conditions=n_initial_conditions, save_names=save_names, - metadata=metadata, + variable_metadata=variable_metadata, coords=coords, ) @@ -57,17 +61,17 @@ def append_batch( target: Dict[str, torch.Tensor], prediction: Dict[str, torch.Tensor], start_timestep: int, - batch_times: xr.DataArray, + batch_time: xr.DataArray, ): self._target_writer.append_batch( data=target, start_timestep=start_timestep, - batch_times=batch_times, + batch_time=batch_time, ) self._prediction_writer.append_batch( data=prediction, start_timestep=start_timestep, - batch_times=batch_times, + batch_time=batch_time, ) def flush(self): @@ -84,36 +88,36 @@ def __init__( self, path: str, label: str, - n_samples: int, + n_initial_conditions: int, save_names: Optional[Sequence[str]], - metadata: Mapping[str, VariableMetadata], + variable_metadata: Mapping[str, VariableMetadata], coords: Mapping[str, np.ndarray], ): """ Args: - filename: Path to write netCDF file(s). - n_samples: Number of samples to write to the file. This might correspond - to a number of initial conditions, or some other grouping of samples. + path: Directory within which to write the file. + label: Name of the file to write. + n_initial_conditions: Number of initial conditions / timeseries + to write to the file. save_names: Names of variables to save in the output file. If None, all provided variables will be saved. - metadata: Metadata for each variable to be written to the file. + variable_metadata: Metadata for each variable to be written to the file. coords: Coordinate data to be written to the file. """ filename = str(Path(path) / label) self._save_names = save_names - self.metadata = metadata + self.variable_metadata = variable_metadata self.coords = coords self.dataset = Dataset(filename, "w", format="NETCDF4") self.dataset.createDimension(LEAD_TIME_DIM, None) # unlimited dimension self.dataset.createVariable(LEAD_TIME_DIM, "i8", (LEAD_TIME_DIM,)) self.dataset.variables[LEAD_TIME_DIM].units = LEAD_TIME_UNITS - self.dataset.createDimension(SAMPLE_DIM, n_samples) - self.dataset.createVariable(INIT_TIME, "i8", (SAMPLE_DIM,)) + self.dataset.createDimension(IC_DIM, n_initial_conditions) + self.dataset.createVariable(INIT_TIME, "i8", (IC_DIM,)) self.dataset.variables[INIT_TIME].units = INIT_TIME_UNITS - self.dataset.createVariable(VALID_TIME, "i8", (SAMPLE_DIM, LEAD_TIME_DIM)) + self.dataset.createVariable(VALID_TIME, "i8", (IC_DIM, LEAD_TIME_DIM)) self.dataset.variables[VALID_TIME].units = INIT_TIME_UNITS - self._n_lat: Optional[int] = None - self._n_lon: Optional[int] = None + self._dataset_dims_created = False def _get_variable_names_to_save( self, *data_varnames: Iterable[str] @@ -124,7 +128,7 @@ def append_batch( self, data: Dict[str, torch.Tensor], start_timestep: int, - batch_times: xr.DataArray, + batch_time: xr.DataArray, ): """ Append a batch of data to the file. @@ -132,37 +136,38 @@ def append_batch( Args: data: Data to be written to file. start_timestep: Timestep (lead time dim) at which to start writing. - batch_times: Time coordinates for each sample in the batch. + batch_time: Time coordinate for each sample in the batch. 
""" + if self.dataset is None: + return n_samples_data = list(data.values())[0].shape[0] - n_samples_time = batch_times.sizes["sample"] + n_samples_time = batch_time.sizes["sample"] if n_samples_data != n_samples_time: raise ValueError( f"Batch size mismatch, data has {n_samples_data} samples " f"and times has {n_samples_time} samples." ) n_times_data = list(data.values())[0].shape[1] - n_times_time = batch_times.sizes["time"] + n_times_time = batch_time.sizes["time"] if n_times_data != n_times_time: raise ValueError( f"Batch time dimension mismatch, data has {n_times_data} times " - f"and times has {n_times_time} times." + f"and time has {n_times_time} times." ) - if self._n_lat is None: - self._n_lat = data[next(iter(data.keys()))].shape[-2] - self.dataset.createDimension("lat", self._n_lat) - if "lat" in self.coords: - self.dataset.createVariable("lat", "f4", ("lat",)) - self.dataset.variables["lat"][:] = self.coords["lat"] - if self._n_lon is None: - self._n_lon = data[next(iter(data.keys()))].shape[-1] - self.dataset.createDimension("lon", self._n_lon) - if "lon" in self.coords: - self.dataset.createVariable("lon", "f4", ("lon",)) - self.dataset.variables["lon"][:] = self.coords["lon"] + if not self._dataset_dims_created: + _dim_info = DIM_INFO_HEALPIX if "face" in self.coords else DIM_INFO_LATLON + _ordered_names = [] + for dim in _dim_info: + dim_size = data[next(iter(data.keys()))].shape[dim.index] + self.dataset.createDimension(dim.name, dim_size) + if dim.name in self.coords: + self.dataset.createVariable(dim.name, "f4", (dim.name,)) + self.dataset.variables[dim.name][:] = self.coords[dim.name] + _ordered_names.append(dim.name) + dims = (IC_DIM, LEAD_TIME_DIM, *_ordered_names) + self._dataset_dims_created = True - dims = (SAMPLE_DIM, LEAD_TIME_DIM, "lat", "lon") save_names = self._get_variable_names_to_save(data.keys()) for variable_name in save_names: # define the variable if it doesn't exist @@ -173,13 +178,13 @@ def append_batch( dims, fill_value=np.nan, ) - if variable_name in self.metadata: - self.dataset.variables[variable_name].units = self.metadata[ + if variable_name in self.variable_metadata: + self.dataset.variables[ variable_name - ].units - self.dataset.variables[variable_name].long_name = self.metadata[ + ].units = self.variable_metadata[variable_name].units + self.dataset.variables[ variable_name - ].long_name + ].long_name = self.variable_metadata[variable_name].long_name self.dataset.variables[variable_name].coordinates = " ".join( [INIT_TIME, VALID_TIME] ) @@ -194,12 +199,12 @@ def append_batch( # handle time dimensions if not hasattr(self.dataset.variables[INIT_TIME], "calendar"): - self.dataset.variables[INIT_TIME].calendar = batch_times.dt.calendar + self.dataset.variables[INIT_TIME].calendar = batch_time.dt.calendar if not hasattr(self.dataset.variables[VALID_TIME], "calendar"): - self.dataset.variables[VALID_TIME].calendar = batch_times.dt.calendar + self.dataset.variables[VALID_TIME].calendar = batch_time.dt.calendar if start_timestep == 0: - init_times: np.ndarray = batch_times.isel(time=0).values + init_times: np.ndarray = batch_time.isel(time=0).values init_times_numeric: np.ndarray = cftime.date2num( init_times, units=self.dataset.variables[INIT_TIME].units, @@ -216,22 +221,22 @@ def append_batch( units=self.dataset.variables[INIT_TIME].units, calendar=self.dataset.variables[INIT_TIME].calendar, ) - lead_times_microseconds = get_batch_lead_times_microseconds( + lead_time_microseconds = get_batch_lead_time_microseconds( init_times, - batch_times.values, 
+ batch_time.values, ) self.dataset.variables[LEAD_TIME_DIM][ - start_timestep : start_timestep + lead_times_microseconds.shape[0] - ] = lead_times_microseconds + start_timestep : start_timestep + lead_time_microseconds.shape[0] + ] = lead_time_microseconds valid_times_numeric: np.ndarray = cftime.date2num( - batch_times.values, + batch_time.values, units=self.dataset.variables[VALID_TIME].units, calendar=self.dataset.variables[VALID_TIME].calendar, ) self.dataset.variables[VALID_TIME][ :, - start_timestep : start_timestep + lead_times_microseconds.shape[0], + start_timestep : start_timestep + lead_time_microseconds.shape[0], ] = valid_times_numeric self.dataset.sync() # Flush the data to disk @@ -243,24 +248,24 @@ def flush(self): self.dataset.sync() -def get_batch_lead_times_microseconds( - init_times: npt.NDArray[cftime.datetime], batch_times: npt.NDArray[cftime.datetime] +def get_batch_lead_time_microseconds( + init_time: npt.NDArray[cftime.datetime], batch_time: npt.NDArray[cftime.datetime] ) -> npt.NDArray[np.int64]: """ - Get the lead times in seconds for the batch. + Get the lead time in seconds for the batch. Assert that they are the same for each sample. Args: - init_times: Initialization time for each sample in the batch. - batch_times: Full array of times for each sample in the batch. + init_time: Initialization time for each sample in the batch. + batch_time: Array of time coordinates for each sample in the batch. Returns: - Lead times in microseconds for the batch + Lead time in microseconds for the batch """ - if init_times.shape[0] != batch_times.shape[0]: + if init_time.shape[0] != batch_time.shape[0]: raise ValueError( - f"Number of init times ({len(init_times)}) must " - f"match number of batch times ({len(batch_times)})" + f"Number of init times ({len(init_time)}) must " + f"match number of batch times ({len(batch_time)})" ) # Carry out timedelta arithmetic in NumPy arrays to avoid xarray's automatic # casting of datetime.timedelta objects to timedelta64[ns] values, which would @@ -268,12 +273,12 @@ def get_batch_lead_times_microseconds( # ~292 years. See # https://numpy.org/doc/stable/reference/arrays.datetime.html#datetime-units # for more details on the limits of various precision timedeltas. 
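# A minimal sketch of the integer conversion performed just below, assuming only
# standard-library datetime semantics: floor-dividing a Python timedelta by
# timedelta(microseconds=1) yields an exact int, which can then be cast to int64
# without ever passing through nanosecond-precision timedelta64 values.
import datetime

example_lead = datetime.timedelta(days=300 * 365)  # roughly a 300-year lead time
example_microseconds = example_lead // datetime.timedelta(microseconds=1)
assert example_microseconds == 300 * 365 * 86400 * 10**6  # exact, well within int64
# The same span expressed as timedelta64[ns] would exceed the ~292-year limit
# described in the comment above, which is why this conversion stays in
# object-dtype / Python timedelta arithmetic.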
- lead_times: npt.NDArray[datetime.timedelta] = ( # type: ignore - batch_times - init_times[:, None] + lead_time: npt.NDArray[datetime.timedelta] = ( # type: ignore + batch_time - init_time[:, None] ) - lead_times_microseconds: npt.NDArray[np.int64] = ( - lead_times // datetime.timedelta(microseconds=1) + lead_time_microseconds: npt.NDArray[np.int64] = ( + lead_time // datetime.timedelta(microseconds=1) ).astype(np.int64) - if not np.all(lead_times_microseconds == lead_times_microseconds[0, :]): + if not np.all(lead_time_microseconds == lead_time_microseconds[0, :]): raise ValueError("Lead times are not the same for each sample in the batch.") - return lead_times_microseconds[0, :] + return lead_time_microseconds[0, :] diff --git a/fme/fme/ace/inference/data_writer/test_data_writer.py b/fme/fme/ace/inference/data_writer/test_data_writer.py index 00b9286..da083d8 100644 --- a/fme/fme/ace/inference/data_writer/test_data_writer.py +++ b/fme/fme/ace/inference/data_writer/test_data_writer.py @@ -8,13 +8,15 @@ import xarray as xr from netCDF4 import Dataset +from fme.ace.data_loading.batch_data import BatchData, PairedData from fme.ace.inference.data_writer.main import ( DataWriter, DataWriterConfig, PairedDataWriter, ) -from fme.ace.inference.data_writer.raw import get_batch_lead_times_microseconds +from fme.ace.inference.data_writer.raw import get_batch_lead_time_microseconds from fme.ace.inference.data_writer.time_coarsen import TimeCoarsenConfig +from fme.core.device import get_device CALENDAR_CFTIME = { "julian": cftime.DatetimeJulian, @@ -34,8 +36,6 @@ def test_data_writer_config_save_names(): save_monthly_files=False, save_histogram_files=False, ) - with pytest.warns(): - DataWriterConfig(names=variable_names, **kwargs) for save_writer in [ "save_prediction_files", "save_monthly_files", @@ -43,7 +43,7 @@ def test_data_writer_config_save_names(): ]: kwargs_copy = kwargs.copy() kwargs_copy.update({save_writer: True}) - DataWriterConfig(names=variable_names, **kwargs_copy) + DataWriterConfig(names=variable_names, **kwargs_copy) # type: ignore class TestDataWriter: @@ -63,11 +63,13 @@ def calendar(self, request): """ return request.param - def get_batch_times(self, start_time, end_time, freq, n_samples, calendar="julian"): + def get_batch_time( + self, start_time, end_time, freq, n_initial_conditions, calendar="julian" + ): datetime_class = CALENDAR_CFTIME[calendar] start_time = datetime_class(*start_time) end_time = datetime_class(*end_time) - batch_times = xr.DataArray( + batch_time = xr.DataArray( xr.cftime_range( start_time, end_time, @@ -76,7 +78,9 @@ def get_batch_times(self, start_time, end_time, freq, n_samples, calendar="julia ).values, dims="time", ) - return xr.concat([batch_times for _ in range(n_samples)], dim="sample") + return xr.concat( + [batch_time for _ in range(n_initial_conditions)], dim="sample" + ) @pytest.fixture def sample_metadata(self): @@ -86,30 +90,54 @@ def sample_metadata(self): } @pytest.fixture - def sample_target_data(self): + def sample_target_data(self, request): + shape = request.param data = { - # sample, time, lat, lon - "temp": torch.rand((2, 3, 4, 5)), - "humidity": torch.rand((2, 3, 4, 5)), # input-only variable + # sample, time, *horizontal_dims + "temp": torch.rand(shape), + "humidity": torch.rand(shape), # input-only variable "pressure": torch.rand( - (2, 3, 4, 5) + shape ), # Extra variable for which there's no metadata - "precipitation": torch.rand((2, 3, 4, 5)), # optionally saved + "precipitation": torch.rand(shape), # optionally saved } return 
data @pytest.fixture - def sample_prediction_data(self): + def sample_prediction_data(self, request): + shape = request.param data = { # sample, time, lat, lon - "temp": torch.rand((2, 3, 4, 5)), + "temp": torch.rand(shape), "pressure": torch.rand( - (2, 3, 4, 5) + shape ), # Extra variable for which there's no metadata - "precipitation": torch.rand((2, 3, 4, 5)), # optionally saved + "precipitation": torch.rand(shape), # optionally saved } return data + @pytest.fixture + def coords(self, request): + return request.param + + @pytest.mark.parametrize( + "sample_target_data, sample_prediction_data, coords", + [ + pytest.param( + (2, 3, 4, 5), + (2, 3, 4, 5), + {"lat": np.arange(4), "lon": np.arange(5)}, + id="LatLon", + ), + pytest.param( + (2, 3, 6, 4, 5), + (2, 3, 6, 4, 5), + {"face": np.arange(6), "height": np.arange(4), "width": np.arange(5)}, + id="HEALPix", + ), + ], + indirect=True, + ) def test_append_batch( self, sample_metadata, @@ -117,52 +145,54 @@ def test_append_batch( sample_prediction_data, tmp_path, calendar, + coords, ): - n_samples = 2 + n_initial_conditions = 2 n_timesteps = 6 writer = PairedDataWriter( str(tmp_path), - n_samples=n_samples, + n_initial_conditions=n_initial_conditions, n_timesteps=n_timesteps, timestep=TIMESTEP, - metadata=sample_metadata, - coords={"lat": np.arange(4), "lon": np.arange(5)}, + variable_metadata=sample_metadata, + coords=coords, enable_prediction_netcdfs=True, enable_video_netcdfs=False, enable_monthly_netcdfs=True, enable_histogram_netcdfs=True, save_names=None, - prognostic_names=[], ) start_time = (2020, 1, 1, 0, 0, 0) end_time = (2020, 1, 1, 12, 0, 0) - batch_times = self.get_batch_times( + batch_time = self.get_batch_time( start_time=start_time, end_time=end_time, - freq="6H", - n_samples=n_samples, + freq="6h", + n_initial_conditions=n_initial_conditions, calendar=calendar, ) writer.append_batch( - sample_target_data, - sample_prediction_data, - start_timestep=0, - batch_times=batch_times, + batch=PairedData( + prediction=sample_prediction_data, + target=sample_target_data, + time=batch_time, + ), ) start_time_2 = (2020, 1, 1, 18, 0, 0) end_time_2 = (2020, 1, 2, 6, 0, 0) - batch_times = self.get_batch_times( + batch_time = self.get_batch_time( start_time=start_time_2, end_time=end_time_2, - freq="6H", - n_samples=n_samples, + freq="6h", + n_initial_conditions=n_initial_conditions, calendar=calendar, ) writer.append_batch( - sample_target_data, - sample_prediction_data, - start_timestep=3, - batch_times=batch_times, + batch=PairedData( + prediction=sample_prediction_data, + target=sample_target_data, + time=batch_time, + ), ) writer.flush() @@ -171,14 +201,14 @@ def test_append_batch( assert dataset["time"].units == "microseconds" assert dataset["init_time"].units == "microseconds since 1970-01-01 00:00:00" assert dataset["init_time"].calendar == calendar + horizontal_shape = (4, 5) if "lat" in coords else (6, 4, 5) for var_name in set(sample_prediction_data.keys()): var_data = dataset.variables[var_name][:] assert var_data.shape == ( - n_samples, + n_initial_conditions, n_timesteps, - 4, - 5, - ) # sample, time, lat, lon + *horizontal_shape, + ) assert not np.isnan(var_data).any(), "unexpected NaNs in prediction data" if var_name in sample_metadata: assert ( @@ -196,7 +226,7 @@ def test_append_batch( # Open the target output file and do smaller set of checks dataset = Dataset(tmp_path / "autoregressive_target.nc", "r") - coord_names = {"time", "init_time", "valid_time", "lat", "lon"} + coord_names = {"time", "init_time", "valid_time", 
*set(coords)} assert set(dataset.variables) == set(sample_target_data) | coord_names # Open the file again with xarray and check the time coordinates, @@ -222,7 +252,10 @@ def test_append_batch( ) xr.testing.assert_equal(ds["time"], expected_lead_times) expected_init_times = xr.DataArray( - [CALENDAR_CFTIME[calendar](*start_time) for _ in range(n_samples)], + [ + CALENDAR_CFTIME[calendar](*start_time) + for _ in range(n_initial_conditions) + ], dims=["sample"], ) expected_init_times = expected_init_times.assign_coords( @@ -245,7 +278,7 @@ def test_append_batch( assert same_count_each_timestep with xr.open_dataset(tmp_path / "monthly_mean_predictions.nc") as ds: - assert ds.counts.sum() == n_samples * n_timesteps + assert ds.counts.sum() == n_initial_conditions * n_timesteps assert np.sum(np.isnan(ds["precipitation"])) == 0 assert np.sum(np.isnan(ds["temp"])) == 0 assert np.sum(np.isnan(ds["pressure"])) == 0 @@ -253,6 +286,11 @@ def test_append_batch( assert np.all(ds.init_time.dt.year.values >= 0) assert np.all(ds.valid_time.dt.month.values >= 0) + @pytest.mark.parametrize( + "sample_target_data, sample_prediction_data", + [pytest.param((2, 3, 4, 5), (2, 3, 4, 5), id="LatLon")], + indirect=True, + ) @pytest.mark.parametrize( ["save_names"], [ @@ -271,31 +309,31 @@ def test_append_batch_save_names( n_samples = 2 writer = PairedDataWriter( str(tmp_path), - n_samples=n_samples, + n_initial_conditions=n_samples, n_timesteps=4, # unused timestep=TIMESTEP, - metadata=sample_metadata, + variable_metadata=sample_metadata, coords={"lat": np.arange(4), "lon": np.arange(5)}, enable_prediction_netcdfs=True, enable_video_netcdfs=False, enable_monthly_netcdfs=True, save_names=save_names, enable_histogram_netcdfs=True, - prognostic_names=save_names or [], ) start_time = (2020, 1, 1, 0, 0, 0) end_time = (2020, 1, 1, 12, 0, 0) - batch_times = self.get_batch_times( + batch_time = self.get_batch_time( start_time=start_time, end_time=end_time, - freq="6H", - n_samples=n_samples, + freq="6h", + n_initial_conditions=n_samples, ) writer.append_batch( - sample_target_data, - sample_prediction_data, - start_timestep=0, - batch_times=batch_times, + batch=PairedData( + prediction=sample_prediction_data, + target=sample_target_data, + time=batch_time, + ), ) writer.flush() dataset = Dataset(tmp_path / "autoregressive_predictions.nc", "r") @@ -329,88 +367,99 @@ def test_append_batch_save_names( } ) + @pytest.mark.parametrize( + "sample_target_data, sample_prediction_data", + [pytest.param((2, 3, 4, 5), (2, 3, 4, 5), id="LatLon")], + indirect=True, + ) def test_append_batch_data_time_mismatch( self, sample_metadata, sample_target_data, sample_prediction_data, tmp_path ): n_samples = 2 writer = PairedDataWriter( str(tmp_path), - n_samples=n_samples, + n_initial_conditions=n_samples, n_timesteps=3, timestep=TIMESTEP, - metadata=sample_metadata, + variable_metadata=sample_metadata, coords={"lat": np.arange(4), "lon": np.arange(5)}, enable_prediction_netcdfs=True, enable_video_netcdfs=False, enable_monthly_netcdfs=True, save_names=None, enable_histogram_netcdfs=True, - prognostic_names=[], ) start_time = (2020, 1, 1, 0, 0, 0) end_time = (2020, 1, 1, 12, 0, 0) - batch_times = self.get_batch_times( + batch_time = self.get_batch_time( start_time=start_time, end_time=end_time, - freq="6H", - n_samples=n_samples + 1, + freq="6h", + n_initial_conditions=n_samples + 1, ) with pytest.raises(ValueError): writer.append_batch( - sample_target_data, - sample_prediction_data, - start_timestep=0, - batch_times=batch_times, + 
batch=PairedData( + prediction=sample_prediction_data, + target=sample_target_data, + time=batch_time, + ), ) def test_prediction_only_append_batch(self, sample_metadata, tmp_path, calendar): n_samples = 2 n_timesteps = 8 coarsen_factor = 2 + device = get_device() prediction_data = { - "temp": torch.rand((n_samples, n_timesteps // coarsen_factor, 4, 5)), - "pressure": torch.rand((n_samples, n_timesteps // coarsen_factor, 4, 5)), + "temp": torch.rand( + (n_samples, n_timesteps // coarsen_factor, 4, 5), device=device + ), + "pressure": torch.rand( + (n_samples, n_timesteps // coarsen_factor, 4, 5), device=device + ), } writer = DataWriter( str(tmp_path), - n_samples=n_samples, + n_initial_conditions=n_samples, n_timesteps=n_timesteps, - metadata=sample_metadata, + variable_metadata=sample_metadata, coords={"lat": np.arange(4), "lon": np.arange(5)}, timestep=TIMESTEP, enable_prediction_netcdfs=True, enable_monthly_netcdfs=True, save_names=None, - prognostic_names=["temp"], time_coarsen=TimeCoarsenConfig(coarsen_factor), ) start_time = (2020, 1, 1, 0, 0, 0) end_time = (2020, 1, 1, 18, 0, 0) - batch_times = self.get_batch_times( + batch_time = self.get_batch_time( start_time=start_time, end_time=end_time, - freq="6H", - n_samples=n_samples, + freq="6h", + n_initial_conditions=n_samples, calendar=calendar, ) writer.append_batch( - prediction_data, - start_timestep=0, - batch_times=batch_times, + batch=BatchData( + data=prediction_data, + time=batch_time, + ), ) start_time_2 = (2020, 1, 2, 0, 0, 0) end_time_2 = (2020, 1, 2, 18, 0, 0) - batch_times = self.get_batch_times( + batch_time = self.get_batch_time( start_time=start_time_2, end_time=end_time_2, - freq="6H", - n_samples=n_samples, + freq="6h", + n_initial_conditions=n_samples, calendar=calendar, ) writer.append_batch( - prediction_data, - start_timestep=4, - batch_times=batch_times, + batch=BatchData( + data=prediction_data, + time=batch_time, + ), ) writer.flush() @@ -428,11 +477,9 @@ def test_prediction_only_append_batch(self, sample_metadata, tmp_path, calendar) assert np.all(ds.init_time.dt.year.values >= 0) assert np.all(ds.valid_time.dt.month.values >= 0) - xr.open_dataset(tmp_path / "restart.nc") - @pytest.mark.parametrize( - ["init_times", "batch_times", "expected"], + ["init_times", "batch_time", "expected"], [ pytest.param( np.array([cftime.DatetimeJulian(2020, 1, 1, 0, 0, 0) for _ in range(3)]), @@ -440,7 +487,7 @@ def test_prediction_only_append_batch(self, sample_metadata, tmp_path, calendar) [ xr.cftime_range( cftime.DatetimeJulian(2020, 1, 1, 0, 0, 0), - freq="6H", + freq="6h", periods=3, ).values for _ in range(3) @@ -462,7 +509,7 @@ def test_prediction_only_append_batch(self, sample_metadata, tmp_path, calendar) [ xr.cftime_range( cftime.DatetimeJulian(2020, 1, 1, 6 * i, 0, 0), - freq="6H", + freq="6h", periods=3, ) for i in range(3) @@ -484,7 +531,7 @@ def test_prediction_only_append_batch(self, sample_metadata, tmp_path, calendar) [ xr.cftime_range( cftime.DatetimeJulian(2020, 1, 2, 6 * i, 0, 0), - freq="6H", + freq="6h", periods=3, ) for i in range(3) @@ -500,50 +547,50 @@ def test_prediction_only_append_batch(self, sample_metadata, tmp_path, calendar) ), ], ) -def test_get_batch_lead_times_microseconds(init_times, batch_times, expected): - lead_time_seconds = get_batch_lead_times_microseconds(init_times, batch_times) +def test_get_batch_lead_times_microseconds(init_times, batch_time, expected): + lead_time_seconds = get_batch_lead_time_microseconds(init_times, batch_time) assert lead_time_seconds.shape == expected.shape 
np.testing.assert_equal(lead_time_seconds, expected) -def test_get_batch_lead_times_microseconds_length_mismatch(): +def test_get_batch_lead_time_microseconds_length_mismatch(): init_times = np.array( [cftime.DatetimeJulian(2020, 1, 1, 6 * i, 0, 0) for i in range(3)] ) - batch_times = np.array( + batch_time = np.array( [ xr.cftime_range( cftime.DatetimeJulian(2020, 1, 2, 6 * i, 0, 0), - freq="6H", + freq="6h", periods=3, ).values for i in range(2) ], ) with pytest.raises(ValueError): - get_batch_lead_times_microseconds(init_times, batch_times) + get_batch_lead_time_microseconds(init_times, batch_time) -def test_get_batch_lead_times_microseconds_inconsistent_samples(): +def test_get_batch_lead_time_microseconds_inconsistent_samples(): init_times = np.array( [cftime.DatetimeJulian(2020, 1, 1, 6, 0, 0) for _ in range(2)] ) - batch_times = np.array( + batch_time = np.array( [ xr.cftime_range( cftime.DatetimeJulian(2020, 1, 1, 6, 0, 0), - freq="6H", + freq="6h", periods=3, ), xr.cftime_range( cftime.DatetimeJulian(2020, 1, 1, 12, 0, 0), - freq="6H", + freq="6h", periods=3, ), ] ) with pytest.raises(ValueError): - get_batch_lead_times_microseconds(init_times, batch_times) + get_batch_lead_time_microseconds(init_times, batch_time) @pytest.mark.parametrize( @@ -554,17 +601,17 @@ def test_get_batch_lead_times_microseconds_inconsistent_samples(): pytest.param(1e6, True, id="1_000_000_years_fails"), ], ) -def test_get_batch_lead_times_microseconds_overflow(years_ahead, overflow): +def test_get_batch_lead_time_microseconds_overflow(years_ahead, overflow): init_times = np.array([cftime.DatetimeNoLeap(2020, 1, 1)]) - batch_times = np.array([cftime.DatetimeNoLeap(2020 + years_ahead, 1, 1)])[:, None] + batch_time = np.array([cftime.DatetimeNoLeap(2020 + years_ahead, 1, 1)])[:, None] days_per_year_noleap = 365 seconds_per_day = 86400 expected_lead_time_microseconds = ( MICROSECONDS_PER_SECOND * seconds_per_day * days_per_year_noleap * years_ahead ) if not overflow: - lead_time = get_batch_lead_times_microseconds(init_times, batch_times) + lead_time = get_batch_lead_time_microseconds(init_times, batch_time) assert lead_time.item() == expected_lead_time_microseconds else: with pytest.raises(OverflowError): - get_batch_lead_times_microseconds(init_times, batch_times) + get_batch_lead_time_microseconds(init_times, batch_time) diff --git a/fme/fme/ace/inference/data_writer/test_main.py b/fme/fme/ace/inference/data_writer/test_main.py new file mode 100644 index 0000000..95b311c --- /dev/null +++ b/fme/fme/ace/inference/data_writer/test_main.py @@ -0,0 +1,89 @@ +import os +import tempfile + +import numpy as np +import torch +import xarray as xr + +from fme.ace.data_loading.batch_data import BatchData +from fme.ace.inference.data_writer.main import _write +from fme.core.dataset.data_typing import VariableMetadata + + +def test_write_single_timestep(): + n_samples = 2 + n_lat = 4 + n_lon = 5 + n_time = 1 + batch = BatchData.new_on_cpu( + data={"air_temperature": torch.rand((n_samples, n_time, n_lat, n_lon))}, + time=xr.DataArray(np.random.rand(n_samples, n_time), dims=["sample", "time"]), + horizontal_dims=["lat", "lon"], + ) + with tempfile.TemporaryDirectory() as tmpdir: + _write( + data=batch, + path=tmpdir, + filename="initial_condition.nc", + variable_metadata={ + "air_temperature": VariableMetadata( + long_name="Air Temperature", units="K" + ) + }, + coords={"lat": np.arange(n_lat), "lon": np.arange(n_lon)}, + ) + filename = os.path.join(tmpdir, "initial_condition.nc") + assert os.path.exists(filename) + with 
xr.open_dataset(filename) as ds: + assert "air_temperature" in ds + assert ds.air_temperature.shape == (n_samples, n_lat, n_lon) + assert ds.time.shape == (n_samples,) + assert ds.air_temperature.dims == ("sample", "lat", "lon") + xr.testing.assert_allclose(ds.time, batch.time.isel(time=0)) + np.testing.assert_allclose( + ds.air_temperature.values, + batch.data["air_temperature"].squeeze(dim=1).cpu().numpy(), + ) + np.testing.assert_allclose(ds.coords["lat"].values, np.arange(n_lat)) + np.testing.assert_allclose(ds.coords["lon"].values, np.arange(n_lon)) + assert ds.air_temperature.attrs["long_name"] == "Air Temperature" + assert ds.air_temperature.attrs["units"] == "K" + + +def test_write_multiple_timesteps(): + n_samples = 2 + n_lat = 4 + n_lon = 5 + n_time = 2 + batch = BatchData.new_on_cpu( + data={"air_temperature": torch.rand((n_samples, n_time, n_lat, n_lon))}, + time=xr.DataArray(np.random.rand(n_samples, n_time), dims=["sample", "time"]), + horizontal_dims=["lat", "lon"], + ) + with tempfile.TemporaryDirectory() as tmpdir: + _write( + data=batch, + path=tmpdir, + filename="initial_condition.nc", + variable_metadata={ + "air_temperature": VariableMetadata( + long_name="Air Temperature", units="K" + ) + }, + coords={"lat": np.arange(n_lat), "lon": np.arange(n_lon)}, + ) + filename = os.path.join(tmpdir, "initial_condition.nc") + assert os.path.exists(filename) + with xr.open_dataset(filename) as ds: + assert "air_temperature" in ds + assert ds.air_temperature.shape == (n_samples, n_time, n_lat, n_lon) + assert ds.time.shape == (n_samples, n_time) + assert ds.air_temperature.dims == ("sample", "time", "lat", "lon") + np.testing.assert_allclose(ds.time.values, batch.time.values) + np.testing.assert_allclose( + ds.air_temperature.values, batch.data["air_temperature"].cpu().numpy() + ) + np.testing.assert_allclose(ds.coords["lat"].values, np.arange(n_lat)) + np.testing.assert_allclose(ds.coords["lon"].values, np.arange(n_lon)) + assert ds.air_temperature.attrs["long_name"] == "Air Temperature" + assert ds.air_temperature.attrs["units"] == "K" diff --git a/fme/fme/ace/inference/data_writer/test_monthly.py b/fme/fme/ace/inference/data_writer/test_monthly.py index 4dac37b..48eb004 100644 --- a/fme/fme/ace/inference/data_writer/test_monthly.py +++ b/fme/fme/ace/inference/data_writer/test_monthly.py @@ -13,7 +13,7 @@ get_days_since_reference, months_for_timesteps, ) -from fme.core.data_loading.data_typing import VariableMetadata +from fme.core.dataset.data_typing import VariableMetadata TIMESTEP = datetime.timedelta(hours=6) @@ -38,7 +38,7 @@ def test_monthly_data_writer(tmpdir, window_size: int, n_writes: int): n_samples=n_samples, n_months=24, save_names=None, - metadata={"x": VariableMetadata(units="m", long_name="x_name")}, + variable_metadata={"x": VariableMetadata(units="m", long_name="x_name")}, coords={}, ) month_values = [] @@ -51,7 +51,7 @@ def test_monthly_data_writer(tmpdir, window_size: int, n_writes: int): month_data = {"x": x_window} initial_time = cftime.DatetimeProlepticGregorian(year, month, 1, 0, 0, 0) for i_write in range(n_writes): - times = xr.DataArray( + time = xr.DataArray( [ [ initial_time + datetime.timedelta(hours=6 * i_write) @@ -61,10 +61,8 @@ def test_monthly_data_writer(tmpdir, window_size: int, n_writes: int): ], dims=["sample", "time"], ) - assert times.shape == (n_samples, window_size) - writer.append_batch( - data=month_data, start_timestep=0, batch_times=times - ) + assert time.shape == (n_samples, window_size) + writer.append_batch(data=month_data, 
start_timestep=0, batch_time=time) writer.flush() written = xr.open_dataset(str(tmpdir / "monthly_mean_predictions.nc")) assert written["x"].shape == (n_samples, 24, n_lat, n_lon) @@ -93,21 +91,46 @@ def test_months_for_timesteps(n_timesteps: int, min_expected: int): assert months_for_timesteps(n_timesteps, TIMESTEP) >= min_expected -def test_get_days_since_reference(): - years = np.array([2020, 2021]) - months = np.array([0, 1]) # expects zero-indexed months - reference_date = cftime.DatetimeProlepticGregorian(2020, 1, 1) +@pytest.mark.parametrize("num_years", [2, 500]) +@pytest.mark.parametrize("calendar", ["proleptic_gregorian", "noleap"]) +def test_get_days_since_reference(num_years, calendar): + first_year = 2020 + final_year = first_year + num_years - 1 + years = np.array([i for i in range(first_year, final_year + 1)]) + months = np.zeros((num_years,), dtype=int) + # For last year set month to 1 + months[-1] = 1 + if calendar == "proleptic_gregorian": + reference_date = cftime.DatetimeProlepticGregorian(2020, 1, 1) + else: + reference_date = cftime.DatetimeNoLeap(2020, 1, 1) n_months = 3 - calendar = "proleptic_gregorian" days = get_days_since_reference(years, months, reference_date, n_months, calendar) - assert days.shape == (2, 3) - assert days[0, 0] == 0 - assert days[0, 1] == 31 - assert days[0, 2] == 31 + 29 - # 2020 is a leap year - assert days[1, 0] == 366 + 31 - assert days[1, 1] == 366 + 31 + 28 - assert days[1, 2] == 366 + 31 + 28 + 31 + assert days.shape == (num_years, 3) + # 2020 is a leap year in proleptic_gregorian + if calendar == "proleptic_gregorian": + assert days[0, 0] == 0 + assert days[0, 1] == 31 + assert days[0, 2] == 31 + 29 + if num_years == 2: + assert days[1, 0] == 366 + 31 + assert days[1, 1] == 366 + 31 + 28 + assert days[1, 2] == 366 + 31 + 28 + 31 + if num_years == 500: + # 121 is number of leap days + assert days[499, 0] == 182135 + 121 + 31 + assert days[499, 1] == 182135 + 121 + 31 + 28 + if calendar == "noleap": + assert days[0, 0] == 0 + assert days[0, 1] == 31 + assert days[0, 2] == 31 + 28 + if num_years == 2: + assert days[1, 0] == 365 + 31 + assert days[1, 1] == 365 + 31 + 28 + assert days[1, 2] == 365 + 31 + 28 + 31 + if num_years == 500: + assert days[499, 0] == 182135 + 31 + assert days[499, 1] == 182135 + 31 + 28 @pytest.mark.parametrize( diff --git a/fme/fme/ace/inference/data_writer/test_restart.py b/fme/fme/ace/inference/data_writer/test_restart.py deleted file mode 100644 index a858830..0000000 --- a/fme/fme/ace/inference/data_writer/test_restart.py +++ /dev/null @@ -1,130 +0,0 @@ -import numpy as np -import torch -import xarray as xr - -from fme.ace.inference.data_writer.restart import RestartWriter -from fme.core.data_loading.data_typing import VariableMetadata - - -def test_restart_saves_last_step(tmpdir): - """ - If multiple steps are configured as restart steps, the last one should be saved. 
- """ - n_sample: int = 3 - n_time: int = 2 - n_lat = 10 - n_lon = 20 - lat = np.linspace(-90, 90, n_lat) - lon = np.linspace(-180, 180, n_lon) - writer = RestartWriter( - path=tmpdir, - is_restart_step=lambda i: True, - prognostic_names=["a", "b"], - metadata={"a": VariableMetadata(long_name="var_a", units="m")}, - coords={"lon": lon, "lat": lat}, - ) - data = { - "a": torch.randn(n_sample, n_time, n_lat, n_lon), - "b": torch.randn(n_sample, n_time, n_lat, n_lon), - } - batch_times = xr.DataArray( - data=np.random.uniform(size=(n_sample, n_time)), - dims=( - "sample", - "time", - ), - ) - writer.append_batch(data, 0, batch_times) - ds = xr.open_dataset(str(tmpdir / "restart.nc")) - np.testing.assert_allclose(ds.a.values, data["a"][:, -1].cpu().numpy()) - np.testing.assert_allclose(ds.b.values, data["b"][:, -1].cpu().numpy()) - np.testing.assert_allclose(ds.time.values, batch_times[:, -1].values) - assert len(ds.b.attrs) == 0 - assert len(ds.a.attrs) == 2 - assert ds.a.attrs["long_name"] == "var_a" - assert ds.a.attrs["units"] == "m" - np.testing.assert_allclose(ds.lon.values, lon) - np.testing.assert_allclose(ds.lat.values, lat) - assert ds.attrs["timestep"] == 1 - - -def test_restart_saves_configured_step(tmpdir): - """ - If a specific step is configured as a restart step, that step should be saved. - """ - n_sample: int = 3 - i_time_start = 4 - i_time_target = 6 - n_time: int = 4 - n_lat = 10 - n_lon = 20 - lat = np.linspace(-90, 90, n_lat) - lon = np.linspace(-180, 180, n_lon) - writer = RestartWriter( - path=tmpdir, - is_restart_step=lambda i: i == i_time_target, - prognostic_names=["a", "b"], - metadata={"a": VariableMetadata(long_name="var_a", units="m")}, - coords={"lon": lon, "lat": lat}, - ) - data = { - "a": torch.randn(n_sample, n_time, n_lat, n_lon), - "b": torch.randn(n_sample, n_time, n_lat, n_lon), - } - batch_times = xr.DataArray( - data=np.random.uniform(size=(n_sample, n_time)), - dims=( - "sample", - "time", - ), - ) - writer.append_batch(data, i_time_start, batch_times) - ds = xr.open_dataset(str(tmpdir / "restart.nc")) - np.testing.assert_allclose( - ds.a.values, data["a"][:, i_time_target - i_time_start].cpu().numpy() - ) - np.testing.assert_allclose( - ds.b.values, data["b"][:, i_time_target - i_time_start].cpu().numpy() - ) - np.testing.assert_allclose( - ds.time.values, batch_times[:, i_time_target - i_time_start].values - ) - assert len(ds.b.attrs) == 0 - assert len(ds.a.attrs) == 2 - assert ds.a.attrs["long_name"] == "var_a" - assert ds.a.attrs["units"] == "m" - np.testing.assert_allclose(ds.lon.values, lon) - np.testing.assert_allclose(ds.lat.values, lat) - assert ds.attrs["timestep"] == i_time_target - - -def test_restart_does_not_save(tmpdir): - """ - If no step is configured to save as restart, no restart should be saved. 
- """ - n_sample: int = 3 - n_time: int = 2 - n_lat = 10 - n_lon = 20 - lat = np.linspace(-90, 90, n_lat) - lon = np.linspace(-180, 180, n_lon) - writer = RestartWriter( - path=tmpdir, - is_restart_step=lambda i: False, - prognostic_names=["a", "b"], - metadata={"a": VariableMetadata(long_name="var_a", units="m")}, - coords={"lon": lon, "lat": lat}, - ) - data = { - "a": torch.randn(n_sample, n_time, n_lat, n_lon), - "b": torch.randn(n_sample, n_time, n_lat, n_lon), - } - batch_times = xr.DataArray( - data=np.random.uniform(size=(n_sample, n_time)), - dims=( - "sample", - "time", - ), - ) - writer.append_batch(data, 0, batch_times) - assert not (tmpdir / "restart.nc").exists() diff --git a/fme/fme/ace/inference/data_writer/test_time_coarsen.py b/fme/fme/ace/inference/data_writer/test_time_coarsen.py index 3902f03..2cb09cb 100644 --- a/fme/fme/ace/inference/data_writer/test_time_coarsen.py +++ b/fme/fme/ace/inference/data_writer/test_time_coarsen.py @@ -21,11 +21,11 @@ def get_windowed_batch(dim_sizes: Sequence[int], start_time: Sequence[int]): .movedim(3, 1) ) target = {VARNAME: data} - times = get_batch_times(n_timesteps, start_time, n_samples=n_samples) - return target, times + time = get_batch_time(n_timesteps, start_time, n_samples=n_samples) + return target, time -def get_batch_times( +def get_batch_time( n_timesteps: int, start_time: Sequence[int], n_samples: int = 2, @@ -49,7 +49,7 @@ def _get_time_array( hours=6 * start_n_offset ) time_index = xr.cftime_range( - start=start_time, periods=n_timesteps, freq=f"{freq_hrs}H", calendar="julian" + start=start_time, periods=n_timesteps, freq=f"{freq_hrs}h", calendar="julian" ) return xr.DataArray(data=time_index, dims=["time"]).drop_vars(["time"]) @@ -59,7 +59,7 @@ def _get_time_array( "coarsen_factor", "start_timestep", "expected_coarsened_data", - "expected_coarsened_times", + "expected_coarsened_time", "expected_coarsened_start_timestep", ], [ @@ -67,7 +67,7 @@ def _get_time_array( 1, 0, [0.0, 1.0, 2.0, 3.0], - get_batch_times(start_time=(2020, 1, 1, 0, 0, 0), n_timesteps=4), + get_batch_time(start_time=(2020, 1, 1, 0, 0, 0), n_timesteps=4), 0, id="coarsen_factor_1", ), @@ -75,7 +75,7 @@ def _get_time_array( 2, 0, [0.5, 2.5], - get_batch_times( + get_batch_time( start_time=(2020, 1, 1, 3, 0, 0), n_timesteps=2, freq_hrs=12 ), 0, @@ -85,7 +85,7 @@ def _get_time_array( 2, 3, [0.5, 2.5], - get_batch_times( + get_batch_time( start_time=(2020, 1, 1, 3, 0, 0), n_timesteps=2, freq_hrs=12 ), 1, @@ -95,7 +95,7 @@ def _get_time_array( 4, 0, [1.5], - get_batch_times( + get_batch_time( start_time=(2020, 1, 1, 9, 0, 0), n_timesteps=1, freq_hrs=24 ), 0, @@ -107,41 +107,45 @@ def test_time_coarsen( coarsen_factor: int, start_timestep: int, expected_coarsened_data: Sequence[float], - expected_coarsened_times: Sequence[cftime.DatetimeJulian], + expected_coarsened_time: Sequence[cftime.DatetimeJulian], expected_coarsened_start_timestep: int, dim_sizes: Sequence[int] = DIM_SIZES, ): - target, times = get_windowed_batch( + target, time = get_windowed_batch( dim_sizes=dim_sizes, start_time=(2020, 1, 1, 0, 0, 0) ) ( target_coarsened, coarsened_start_timestep, - times_coarsened, + time_coarsened, ) = coarsen_batch( data=target, start_timestep=start_timestep, - batch_times=times, + batch_time=time, coarsen_factor=coarsen_factor, ) # check the coarsened data time dim size assert target_coarsened[VARNAME].size(dim=1) == len( expected_coarsened_data ), "target coarsened time dim" - assert times_coarsened.sizes["time"] == len( + assert time_coarsened.sizes["time"] 
== len( expected_coarsened_data - ), "times coarsened time dim" + ), "time coarsened time dim" # check the coarsened data values n_samples, _, n_lat, n_lon = dim_sizes - torch.testing.assert_close( - target_coarsened[VARNAME], - torch.tensor(expected_coarsened_data, dtype=torch.float64) - .repeat(n_samples, n_lat, n_lon, 1) - .movedim(3, 1), - ), "target coarsened value" + ( + torch.testing.assert_close( + target_coarsened[VARNAME], + torch.tensor(expected_coarsened_data, dtype=torch.float64) + .repeat(n_samples, n_lat, n_lon, 1) + .movedim(3, 1), + ), + "target coarsened value", + ) # check the coarsened start timestep assert coarsened_start_timestep == expected_coarsened_start_timestep # check the coarsened data time coordinate values - xr.testing.assert_allclose( - times_coarsened, expected_coarsened_times - ), "times initial condition value" + ( + xr.testing.assert_allclose(time_coarsened, expected_coarsened_time), + "time initial condition value", + ) diff --git a/fme/fme/ace/inference/data_writer/time_coarsen.py b/fme/fme/ace/inference/data_writer/time_coarsen.py index 21d294a..a677b74 100644 --- a/fme/fme/ace/inference/data_writer/time_coarsen.py +++ b/fme/fme/ace/inference/data_writer/time_coarsen.py @@ -14,7 +14,7 @@ def append_batch( target: Dict[str, torch.Tensor], prediction: Dict[str, torch.Tensor], start_timestep: int, - batch_times: xr.DataArray, + batch_time: xr.DataArray, ): pass @@ -27,7 +27,7 @@ def append_batch( self, data: Dict[str, torch.Tensor], start_timestep: int, - batch_times: xr.DataArray, + batch_time: xr.DataArray, ): pass @@ -66,7 +66,7 @@ def build(self, data_writer: _DataWriter) -> "TimeCoarsen": ) def n_coarsened_timesteps(self, n_timesteps: int) -> int: - """Assumes initial condition is NOT in n_timesteps""" + """Assumes initial condition is NOT in n_timesteps.""" return (n_timesteps) // self.coarsen_factor @@ -86,13 +86,13 @@ def append_batch( target: Dict[str, torch.Tensor], prediction: Dict[str, torch.Tensor], start_timestep: int, - batch_times: xr.DataArray, + batch_time: xr.DataArray, ): (target_coarsened, start_timestep, batch_times_coarsened) = coarsen_batch( - target, start_timestep, batch_times, self._coarsen_factor + target, start_timestep, batch_time, self._coarsen_factor ) (prediction_coarsened, _, _) = coarsen_batch( - prediction, start_timestep, batch_times, self._coarsen_factor + prediction, start_timestep, batch_time, self._coarsen_factor ) self._data_writer.append_batch( target_coarsened, @@ -120,10 +120,10 @@ def append_batch( self, data: Dict[str, torch.Tensor], start_timestep: int, - batch_times: xr.DataArray, + batch_time: xr.DataArray, ): (data_coarsened, start_timestep, batch_times_coarsened) = coarsen_batch( - data, start_timestep, batch_times, self._coarsen_factor + data, start_timestep, batch_time, self._coarsen_factor ) self._data_writer.append_batch( data_coarsened, @@ -138,13 +138,13 @@ def flush(self): def coarsen_batch( data: Dict[str, torch.Tensor], start_timestep: int, - batch_times: xr.DataArray, + batch_time: xr.DataArray, coarsen_factor: int, ) -> Tuple[Dict[str, torch.Tensor], int, xr.DataArray]: data_coarsened = _coarsen_tensor_dict(data, coarsen_factor) start_timestep = start_timestep // coarsen_factor - batch_times_coarsened = batch_times.coarsen({TIME_DIM_NAME: coarsen_factor}).mean() - return data_coarsened, start_timestep, batch_times_coarsened + batch_time_coarsened = batch_time.coarsen({TIME_DIM_NAME: coarsen_factor}).mean() + return data_coarsened, start_timestep, batch_time_coarsened def _coarsen_tensor_dict( 
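A minimal sketch of the time-coarsening semantics exercised by the tests above: tensors shaped [sample, time, lat, lon] are averaged over non-overlapping windows of length coarsen_factor along the time axis, the window start index shrinks by the same factor, and the time coordinate is averaged over the matching windows. The helper name coarsen_sketch and the assumption that the time length divides evenly by the factor are illustrative only, not part of the fme package API.

from typing import Dict, Tuple

import torch
import xarray as xr


def coarsen_sketch(
    data: Dict[str, torch.Tensor],
    start_timestep: int,
    time: xr.DataArray,
    factor: int,
) -> Tuple[Dict[str, torch.Tensor], int, xr.DataArray]:
    # Average non-overlapping windows of length `factor` along the time axis (dim 1).
    coarsened = {
        name: tensor.unfold(dimension=1, size=factor, step=factor).mean(dim=-1)
        for name, tensor in data.items()
    }
    # A window that started at raw step 3 with factor 2 starts at coarsened step 1.
    new_start = start_timestep // factor
    # The time coordinate is averaged over the same non-overlapping windows.
    time_coarsened = time.coarsen({"time": factor}).mean()
    return coarsened, new_start, time_coarsened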
diff --git a/fme/fme/ace/inference/data_writer/utils.py b/fme/fme/ace/inference/data_writer/utils.py index 6fd7e4d..52bfb7d 100644 --- a/fme/fme/ace/inference/data_writer/utils.py +++ b/fme/fme/ace/inference/data_writer/utils.py @@ -1,3 +1,4 @@ +from dataclasses import dataclass from typing import Iterable, Optional, Set, TypeVar T = TypeVar("T") @@ -17,3 +18,21 @@ def get_all_names( return variables else: return variables.intersection(set(allowlist)) + + +@dataclass +class DimInfo: + name: str + index: int + + +DIM_INFO_LATLON = [ + DimInfo(name="lat", index=-2), + DimInfo(name="lon", index=-1), +] + +DIM_INFO_HEALPIX = [ + DimInfo(name="face", index=-3), + DimInfo(name="height", index=-2), + DimInfo(name="width", index=-1), +] diff --git a/fme/fme/ace/inference/data_writer/video.py b/fme/fme/ace/inference/data_writer/video.py index fd95d01..e55ad80 100644 --- a/fme/fme/ace/inference/data_writer/video.py +++ b/fme/fme/ace/inference/data_writer/video.py @@ -5,8 +5,8 @@ import torch import xarray as xr -from fme.core.aggregator.inference.video import VideoAggregator -from fme.core.data_loading.data_typing import VariableMetadata +from fme.ace.aggregator.inference.video import VideoAggregator +from fme.core.dataset.data_typing import VariableMetadata class PairedVideoDataWriter: @@ -18,25 +18,27 @@ def __init__( self, path: str, n_timesteps: int, - metadata: Mapping[str, VariableMetadata], + variable_metadata: Mapping[str, VariableMetadata], coords: Mapping[str, np.ndarray], ): """ Args: - filename: Path to write netCDF file(s). + path: Directory within which to write the file. n_samples: Number of samples to write to the file. n_timesteps: Number of timesteps to write to the file. - metadata: Metadata for each variable to be written to the file. + variable_metadata: Metadata for each variable to be written to the file. coords: Coordinate data to be written to the file. """ self.path = path self._metrics_filename = str( Path(path) / "reduced_autoregressive_predictions.nc" ) - self.metadata = metadata + self.variable_metadata = variable_metadata self.coords = coords self._video = VideoAggregator( - n_timesteps=n_timesteps, enable_extended_videos=True, metadata=metadata + n_timesteps=n_timesteps, + enable_extended_videos=True, + variable_metadata=variable_metadata, ) def append_batch( @@ -44,7 +46,7 @@ def append_batch( target: Dict[str, torch.Tensor], prediction: Dict[str, torch.Tensor], start_timestep: int, - batch_times: xr.DataArray, + batch_time: xr.DataArray, ): """ Append a batch of data to the file. @@ -53,10 +55,9 @@ def append_batch( target: Target data. prediction: Prediction data. start_timestep: Timestep at which to start writing. - batch_times: Time coordinates for each sample in the batch. Unused. + batch_time: Time coordinate for each sample in the batch. Unused. 
""" self._video.record_batch( - loss=np.nan, target_data=target, gen_data=prediction, i_time_start=start_timestep, diff --git a/fme/fme/ace/inference/derived_variables.py b/fme/fme/ace/inference/derived_variables.py index c5a673d..429bcde 100644 --- a/fme/fme/ace/inference/derived_variables.py +++ b/fme/fme/ace/inference/derived_variables.py @@ -1,86 +1,76 @@ -import dataclasses import datetime import logging -from typing import Callable, Dict, List, MutableMapping, Optional +from typing import Callable, Dict, MutableMapping, Optional import torch from fme.core import metrics from fme.core.climate_data import ClimateData -from fme.core.data_loading.data_typing import SigmaCoordinates +from fme.core.coordinates import HybridSigmaPressureCoordinate from fme.core.device import get_device -from fme.core.stepper import SteppedData +DerivedVariableFunc = Callable[ + [ClimateData, HybridSigmaPressureCoordinate, datetime.timedelta], torch.Tensor +] -@dataclasses.dataclass -class DerivedVariableRegistryEntry: - func: Callable[[ClimateData, SigmaCoordinates, datetime.timedelta], torch.Tensor] - required_inputs: Optional[List[str]] = None +_DERIVED_VARIABLE_REGISTRY: MutableMapping[str, DerivedVariableFunc] = {} -_DERIVED_VARIABLE_REGISTRY: MutableMapping[str, DerivedVariableRegistryEntry] = {} +def register(func: DerivedVariableFunc): + label = func.__name__ + if label in _DERIVED_VARIABLE_REGISTRY: + raise ValueError(f"Function {label} has already been added to registry.") + _DERIVED_VARIABLE_REGISTRY[label] = func + return func -def register( - required_inputs: Optional[List[str]] = None, -): - """Decorator for registering a function that computes a derived variable.""" - - def decorator( - func: Callable[ - [ClimateData, SigmaCoordinates, datetime.timedelta], torch.Tensor - ] - ): - label = func.__name__ - if label in _DERIVED_VARIABLE_REGISTRY: - raise ValueError(f"Function {label} has already been added to registry.") - _DERIVED_VARIABLE_REGISTRY[label] = DerivedVariableRegistryEntry( - func=func, required_inputs=required_inputs - ) - return func - - return decorator - -@register() +@register def surface_pressure_due_to_dry_air( data: ClimateData, - sigma_coordinates: SigmaCoordinates, + vertical_coordinate: HybridSigmaPressureCoordinate, timestep: datetime.timedelta, ) -> torch.Tensor: return metrics.surface_pressure_due_to_dry_air( data.specific_total_water, data.surface_pressure, - sigma_coordinates.ak, - sigma_coordinates.bk, + vertical_coordinate, ) -@register() +@register +def surface_pressure_due_to_dry_air_absolute_tendency( + data: ClimateData, + vertical_coordinate: HybridSigmaPressureCoordinate, + timestep: datetime.timedelta, +) -> torch.Tensor: + ps_dry = surface_pressure_due_to_dry_air(data, vertical_coordinate, timestep) + abs_ps_dry_tendency = torch.zeros_like(ps_dry) + abs_ps_dry_tendency[:, 1:] = torch.diff(ps_dry, n=1, dim=1).abs() + return abs_ps_dry_tendency + + +@register def total_water_path( data: ClimateData, - sigma_coordinates: SigmaCoordinates, + vertical_coordinate: HybridSigmaPressureCoordinate, timestep: datetime.timedelta, ) -> torch.Tensor: - return metrics.vertical_integral( + return vertical_coordinate.vertical_integral( data.specific_total_water, data.surface_pressure, - sigma_coordinates.ak, - sigma_coordinates.bk, ) -@register() +@register def total_water_path_budget_residual( data: ClimateData, - sigma_coordinates: SigmaCoordinates, + vertical_coordinate: HybridSigmaPressureCoordinate, timestep: datetime.timedelta, ): - total_water_path = 
metrics.vertical_integral( + total_water_path = vertical_coordinate.vertical_integral( data.specific_total_water, data.surface_pressure, - sigma_coordinates.ak, - sigma_coordinates.bk, ) twp_total_tendency = (total_water_path[:, 1:] - total_water_path[:, :-1]) / ( timestep.total_seconds() @@ -95,23 +85,23 @@ def total_water_path_budget_residual( return twp_budget_residual -@register( - required_inputs=[ - "DSWRFtoa", - ] -) +@register def net_energy_flux_toa_into_atmosphere( data: ClimateData, - sigma_coordinates: SigmaCoordinates, + vertical_coordinate: HybridSigmaPressureCoordinate, timestep: datetime.timedelta, ): - return data.data["DSWRFtoa"] - data.data["USWRFtoa"] - data.data["ULWRFtoa"] + return ( + data.toa_down_sw_radiative_flux + - data.toa_up_sw_radiative_flux + - data.toa_up_lw_radiative_flux + ) -@register() +@register def net_energy_flux_sfc_into_atmosphere( data: ClimateData, - sigma_coordinates: SigmaCoordinates, + vertical_coordinate: HybridSigmaPressureCoordinate, timestep: datetime.timedelta, ): # property is defined as positive into surface, but want to compare to @@ -119,38 +109,36 @@ def net_energy_flux_sfc_into_atmosphere( return -data.net_surface_energy_flux_without_frozen_precip -@register(required_inputs=["DSWRFtoa"]) +@register def net_energy_flux_into_atmospheric_column( data: ClimateData, - sigma_coordinates: SigmaCoordinates, + vertical_coordinate: HybridSigmaPressureCoordinate, timestep: datetime.timedelta, ): return net_energy_flux_sfc_into_atmosphere( - data, sigma_coordinates, timestep - ) + net_energy_flux_toa_into_atmosphere(data, sigma_coordinates, timestep) + data, vertical_coordinate, timestep + ) + net_energy_flux_toa_into_atmosphere(data, vertical_coordinate, timestep) -@register(required_inputs=["HGTsfc"]) +@register def column_moist_static_energy( data: ClimateData, - sigma_coordinates: SigmaCoordinates, + vertical_coordinate: HybridSigmaPressureCoordinate, timestep: datetime.timedelta, ): - return metrics.vertical_integral( - data.moist_static_energy(sigma_coordinates), + return vertical_coordinate.vertical_integral( + data.moist_static_energy(vertical_coordinate), data.surface_pressure, - sigma_coordinates.ak, - sigma_coordinates.bk, ) -@register(required_inputs=["HGTsfc"]) +@register def column_moist_static_energy_tendency( data: ClimateData, - sigma_coordinates: SigmaCoordinates, + vertical_coordinate: HybridSigmaPressureCoordinate, timestep: datetime.timedelta, ): - mse = column_moist_static_energy(data, sigma_coordinates, timestep) + mse = column_moist_static_energy(data, vertical_coordinate, timestep) diff = torch.diff(mse, n=1, dim=1) # Only the very first timestep in series is filled with nan; subsequent batches # drop the first step as it's the initial condition. @@ -164,32 +152,31 @@ def column_moist_static_energy_tendency( def _compute_derived_variable( data: Dict[str, torch.Tensor], - sigma_coordinates: SigmaCoordinates, + vertical_coordinate: HybridSigmaPressureCoordinate, timestep: datetime.timedelta, label: str, - derived_variable: DerivedVariableRegistryEntry, + derived_variable_func: DerivedVariableFunc, forcing_data: Optional[Dict[str, torch.Tensor]] = None, ) -> Dict[str, torch.Tensor]: """Computes a derived variable and adds it to the given data. - If the required input data is not available, - no change will be made to the data. + The derived variable name must not already exist in the data. + + If any required input data are not available, + the derived variable will not be computed. 
Args: data: dictionary of data add the derived variable to. - sigma_coordinates: the vertical coordinate. + vertical_coordinate: the vertical coordinate. timestep: Timestep of the model. label: the name of the derived variable. - derived_variable: class indicating required names and function to compute. + derived_variable_func: derived variable function to compute. forcing_data: optional dictionary of forcing data needed for some derived variables. If necessary forcing inputs are missing, the derived variable will not be computed. Returns: - A new SteppedData instance with the derived variable added. - - Note: - Derived variables are only computed for the denormalized data in stepped. + A new data dictionary with the derived variable added. """ if label in data: raise ValueError( @@ -197,15 +184,15 @@ def _compute_derived_variable( "to overwrite existing variables with derived variables." ) new_data = data.copy() + if forcing_data is not None: + for key, value in forcing_data.items(): + if key not in data: + data[key] = value + climate_data = ClimateData(data) try: - if forcing_data and derived_variable.required_inputs: - for v in derived_variable.required_inputs: - if v not in forcing_data: - raise KeyError(v) - climate_data.data.update({v: forcing_data[v]}) - output = derived_variable.func(climate_data, sigma_coordinates, timestep) + output = derived_variable_func(climate_data, vertical_coordinate, timestep) except KeyError as key_error: logging.debug(f"Could not compute {label} because {key_error} is missing") else: # if no exception was raised @@ -215,43 +202,18 @@ def _compute_derived_variable( def compute_derived_quantities( data: Dict[str, torch.Tensor], - sigma_coordinates: SigmaCoordinates, + vertical_coordinate: HybridSigmaPressureCoordinate, timestep: datetime.timedelta, - registry: MutableMapping[ - str, DerivedVariableRegistryEntry - ] = _DERIVED_VARIABLE_REGISTRY, forcing_data: Optional[Dict[str, torch.Tensor]] = None, ) -> Dict[str, torch.Tensor]: """Computes all derived quantities from the given data.""" - for label, derived_variable in registry.items(): + for label, func in _DERIVED_VARIABLE_REGISTRY.items(): data = _compute_derived_variable( data, - sigma_coordinates, + vertical_coordinate, timestep, label, - derived_variable, + func, forcing_data=forcing_data, ) return data - - -def compute_stepped_derived_quantities( - stepped: SteppedData, - sigma_coordinates: SigmaCoordinates, - timestep: datetime.timedelta, - registry: MutableMapping[ - str, DerivedVariableRegistryEntry - ] = _DERIVED_VARIABLE_REGISTRY, - forcing_data: Optional[Dict[str, torch.Tensor]] = None, -) -> SteppedData: - stepped.gen_data = compute_derived_quantities( - stepped.gen_data, sigma_coordinates, timestep, registry, forcing_data - ) - stepped.target_data = compute_derived_quantities( - stepped.target_data, - sigma_coordinates, - timestep, - registry, - forcing_data, - ) - return stepped diff --git a/fme/fme/ace/inference/evaluator.py b/fme/fme/ace/inference/evaluator.py index 1ce48ce..f3abbbb 100755 --- a/fme/fme/ace/inference/evaluator.py +++ b/fme/fme/ace/inference/evaluator.py @@ -2,9 +2,7 @@ import dataclasses import logging import os -import time -from pathlib import Path -from typing import Optional, Sequence +from typing import Callable, Optional import dacite import torch @@ -12,44 +10,53 @@ import fme import fme.core.logging_utils as logging_utils +from fme.ace.aggregator.inference import InferenceEvaluatorAggregatorConfig +from fme.ace.data_loading.batch_data import BatchData, 
InferenceGriddedData +from fme.ace.data_loading.getters import get_inference_data +from fme.ace.data_loading.inference import InferenceDataLoaderConfig from fme.ace.inference.data_writer import DataWriterConfig, PairedDataWriter from fme.ace.inference.data_writer.time_coarsen import TimeCoarsenConfig -from fme.ace.inference.loop import run_dataset_comparison, run_inference_evaluator -from fme.core import SingleModuleStepper -from fme.core.aggregator.inference import InferenceEvaluatorAggregatorConfig -from fme.core.data_loading.data_typing import GriddedData, SigmaCoordinates -from fme.core.data_loading.getters import get_inference_data -from fme.core.data_loading.inference import InferenceDataLoaderConfig +from fme.ace.inference.loop import ( + DeriverABC, + run_dataset_comparison, + write_reduced_metrics, +) +from fme.ace.stepper import SingleModuleStepper, SingleModuleStepperConfig from fme.core.dicts import to_flat_dict +from fme.core.generics.inference import get_record_to_wandb, run_inference from fme.core.logging_utils import LoggingConfig from fme.core.ocean import OceanConfig -from fme.core.stepper import SingleModuleStepperConfig -from fme.core.wandb import WandB +from fme.core.timing import GlobalTimer +from fme.core.typing_ import TensorDict, TensorMapping -def load_stepper_config(checkpoint_file: str) -> SingleModuleStepperConfig: +def load_stepper_config( + checkpoint_file: str, ocean_config: Optional[OceanConfig] +) -> SingleModuleStepperConfig: checkpoint = torch.load(checkpoint_file, map_location=fme.get_device()) - return SingleModuleStepperConfig.from_state(checkpoint["stepper"]["config"]) + config = SingleModuleStepperConfig.from_state(checkpoint["stepper"]["config"]) + if ocean_config is not None: + logging.info( + "Overriding training ocean configuration with the inference ocean config." + ) + config.ocean = ocean_config + return config def load_stepper( checkpoint_file: str, - area: torch.Tensor, - sigma_coordinates: SigmaCoordinates, ocean_config: Optional[OceanConfig] = None, ) -> SingleModuleStepper: checkpoint = torch.load(checkpoint_file, map_location=fme.get_device()) - stepper = SingleModuleStepper.from_state( - checkpoint["stepper"], area=area, sigma_coordinates=sigma_coordinates - ) + stepper = SingleModuleStepper.from_state(checkpoint["stepper"]) if ocean_config is not None: logging.info( "Overriding training ocean configuration with the inference ocean config." ) new_ocean = ocean_config.build( - stepper.in_packer.names, stepper.out_packer.names, stepper.timestep + stepper.in_names, stepper.out_names, stepper.timestep ) - stepper.ocean = new_ocean + stepper.replace_ocean(new_ocean) return stepper @@ -76,7 +83,7 @@ class InferenceEvaluatorConfig: """ Configuration for running inference including comparison to reference data. - Attributes: + Parameters: experiment_dir: Directory to save results to. n_forward_steps: Number of steps to run the model forward for. checkpoint_path: Path to stepper checkpoint to load. @@ -129,41 +136,22 @@ def configure_wandb(self, env_vars: Optional[dict] = None, **kwargs): def clean_wandb(self): self.logging.clean_wandb(self.experiment_dir) - def configure_gcs(self): - self.logging.configure_gcs() - - def load_stepper( - self, area: torch.Tensor, sigma_coordinates: SigmaCoordinates - ) -> SingleModuleStepper: - """ - Args: - area: A tensor of shape (n_lat, n_lon) containing the area of - each grid cell. - sigma_coordinates: The sigma coordinates of the model. 
- """ + def load_stepper(self) -> SingleModuleStepper: logging.info(f"Loading trained model checkpoint from {self.checkpoint_path}") - stepper = load_stepper( - self.checkpoint_path, - area=area, - sigma_coordinates=sigma_coordinates, - ocean_config=self.ocean, - ) + stepper = load_stepper(self.checkpoint_path, ocean_config=self.ocean) return stepper def load_stepper_config(self) -> SingleModuleStepperConfig: logging.info(f"Loading trained model checkpoint from {self.checkpoint_path}") - return load_stepper_config(self.checkpoint_path) + return load_stepper_config(self.checkpoint_path, ocean_config=self.ocean) - def get_data_writer( - self, data: GriddedData, prognostic_names: Sequence[str] - ) -> PairedDataWriter: + def get_data_writer(self, data: InferenceGriddedData) -> PairedDataWriter: return self.data_writer.build_paired( experiment_dir=self.experiment_dir, - n_samples=self.loader.n_samples, + n_initial_conditions=self.loader.n_initial_conditions, n_timesteps=self.n_forward_steps, timestep=data.timestep, - prognostic_names=prognostic_names, - metadata=data.metadata, + variable_metadata=data.variable_metadata, coords=data.coords, ) @@ -180,10 +168,45 @@ def main(yaml_config: str): os.makedirs(config.experiment_dir, exist_ok=True) with open(os.path.join(config.experiment_dir, "config.yaml"), "w") as f: yaml.dump(data, f, default_flow_style=False, sort_keys=False) - return run_evaluator_from_config(config) + with GlobalTimer(): + return run_evaluator_from_config(config) + + +class _Deriver(DeriverABC): + """ + DeriverABC implementation for dataset comparison. + """ + + def __init__( + self, + n_ic_timesteps: int, + derive_func: Callable[[TensorMapping, TensorMapping], TensorDict], + ): + self._n_ic_timesteps = n_ic_timesteps + self._derive_func = derive_func + + @property + def n_ic_timesteps(self) -> int: + return self._n_ic_timesteps + + def get_forward_data( + self, data: BatchData, compute_derived_variables: bool = False + ) -> BatchData: + if compute_derived_variables: + timer = GlobalTimer.get_instance() + with timer.context("compute_derived_variables"): + data = data.compute_derived_variables( + derive_func=self._derive_func, + forcing_data=data, + ) + return data.remove_initial_condition(self._n_ic_timesteps) def run_evaluator_from_config(config: InferenceEvaluatorConfig): + timer = GlobalTimer.get_instance() + timer.start_outer("inference") + timer.start("initialization") + if not os.path.isdir(config.experiment_dir): os.makedirs(config.experiment_dir, exist_ok=True) config.configure_logging(log_filename="inference_out.log") @@ -196,22 +219,22 @@ def run_evaluator_from_config(config: InferenceEvaluatorConfig): logging_utils.log_versions() logging.info(f"Current device is {fme.get_device()}") - start_time = time.time() stepper_config = config.load_stepper_config() logging.info("Loading inference data") - data_requirements = stepper_config.get_data_requirements( - n_forward_steps=config.n_forward_steps + window_requirements = stepper_config.get_evaluation_window_data_requirements( + n_forward_steps=config.forward_steps_in_memory + ) + initial_condition_requirements = ( + stepper_config.get_prognostic_state_data_requirements() ) data = get_inference_data( - config.loader, - config.forward_steps_in_memory, - data_requirements, + config=config.loader, + total_forward_steps=config.n_forward_steps, + window_requirements=window_requirements, + initial_condition=initial_condition_requirements, ) - stepper = config.load_stepper( - data.area_weights.to(fme.get_device()), - 
sigma_coordinates=data.sigma_coordinates.to(fme.get_device()), - ) + stepper = config.load_stepper() if stepper.timestep != data.timestep: raise ValueError( f"Timestep of the loaded stepper, {stepper.timestep}, does not " @@ -220,81 +243,86 @@ def run_evaluator_from_config(config: InferenceEvaluatorConfig): aggregator_config: InferenceEvaluatorAggregatorConfig = config.aggregator for batch in data.loader: - initial_times = batch.times.isel(time=0) + initial_time = batch.time.isel(time=0) break aggregator = aggregator_config.build( - area_weights=data.area_weights.to(fme.get_device()), - sigma_coordinates=data.sigma_coordinates, + vertical_coordinate=data.vertical_coordinate, + horizontal_coordinates=data.horizontal_coordinates, timestep=data.timestep, record_step_20=config.n_forward_steps >= 20, - n_timesteps=config.n_forward_steps + 1, - metadata=data.metadata, - data_grid=data.grid, - initial_times=initial_times, + n_timesteps=config.n_forward_steps + stepper_config.n_ic_timesteps, + variable_metadata=data.variable_metadata, + initial_time=initial_time, + channel_mean_names=stepper.out_names, + normalize=stepper.normalizer.normalize, ) - writer = config.get_data_writer(data, stepper.prognostic_names) + writer = config.get_data_writer(data) + timer.stop() logging.info("Starting inference") + record_logs = get_record_to_wandb(label="inference") if config.prediction_loader is not None: prediction_data = get_inference_data( config.prediction_loader, - config.forward_steps_in_memory, - data_requirements, + total_forward_steps=config.n_forward_steps, + window_requirements=window_requirements, + initial_condition=initial_condition_requirements, ) - - timers = run_dataset_comparison( + deriver = _Deriver( + n_ic_timesteps=stepper_config.n_ic_timesteps, + derive_func=stepper.derive_func, + ) + run_dataset_comparison( aggregator=aggregator, - normalizer=stepper.normalizer, prediction_data=prediction_data, target_data=data, + deriver=deriver, writer=writer, + record_logs=record_logs, ) else: - timers = run_inference_evaluator( + run_inference( + predict=stepper.predict_paired, + data=data, aggregator=aggregator, writer=writer, - stepper=stepper, - data=data, + record_logs=record_logs, ) - final_flush_start_time = time.time() + timer.start("final_writer_flush") logging.info("Starting final flush of data writer") writer.flush() logging.info("Writing reduced metrics to disk in netcdf format.") - for name, ds in aggregator.get_datasets( - ("time_mean", "zonal_mean", "histogram") - ).items(): - coords = {k: v for k, v in data.coords.items() if k in ds.dims} - ds = ds.assign_coords(coords) - ds.to_netcdf(Path(config.experiment_dir) / f"{name}_diagnostics.nc") - - final_flush_duration = time.time() - final_flush_start_time - logging.info(f"Final writer flush duration: {final_flush_duration:.2f} seconds") - timers["final_writer_flush"] = final_flush_duration - - duration = time.time() - start_time - total_steps = config.n_forward_steps * config.loader.n_samples - total_steps_per_second = total_steps / duration - logging.info(f"Inference duration: {duration:.2f} seconds") - logging.info(f"Total steps per second: {total_steps_per_second:.2f} steps/second") - - step_logs = aggregator.get_inference_logs(label="inference") - wandb = WandB.get_instance() - if wandb.enabled: - logging.info("Starting logging of metrics to wandb") - duration_logs = { - "duration_seconds": duration, - "total_steps_per_second": total_steps_per_second, - } - wandb.log({**timers, **duration_logs}, step=0) - for i, log in 
enumerate(step_logs): - wandb.log(log, step=i, sleep=0.01) + write_reduced_metrics( + aggregator, + data.coords, + config.experiment_dir, + excluded=[ + "video", + ], + ) + timer.stop() + + timer.stop_outer("inference") + total_steps = config.n_forward_steps * config.loader.n_initial_conditions + inference_duration = timer.get_duration("inference") + wandb_logging_duration = timer.get_duration("wandb_logging") + total_steps_per_second = total_steps / (inference_duration - wandb_logging_duration) + timer.log_durations() + logging.info( + "Total steps per second (ignoring wandb logging): " + f"{total_steps_per_second:.2f} steps/second" + ) + summary_logs = { + "total_steps_per_second": total_steps_per_second, + **timer.get_durations(), + **aggregator.get_summary_logs(), + } + record_logs([summary_logs]) config.clean_wandb() - return step_logs - if __name__ == "__main__": parser = argparse.ArgumentParser() diff --git a/fme/fme/ace/inference/inference.py b/fme/fme/ace/inference/inference.py index 88a9173..6796d45 100644 --- a/fme/fme/ace/inference/inference.py +++ b/fme/fme/ace/inference/inference.py @@ -1,10 +1,8 @@ -import argparse +import copy import dataclasses import logging import os -import time -from pathlib import Path -from typing import Literal, Optional, Sequence, Tuple, Union +from typing import Literal, Optional, Sequence, Union import dacite import torch @@ -13,24 +11,27 @@ import fme import fme.core.logging_utils as logging_utils -from fme.ace.inference.data_writer import DataWriter, DataWriterConfig -from fme.ace.inference.loop import run_inference -from fme.core import SingleModuleStepper -from fme.core.aggregator.inference import InferenceAggregatorConfig -from fme.core.data_loading.data_typing import GriddedData, SigmaCoordinates -from fme.core.data_loading.getters import get_forcing_data -from fme.core.data_loading.inference import ( +from fme.ace.aggregator.inference import InferenceAggregatorConfig +from fme.ace.data_loading.batch_data import ( + BatchData, + InferenceGriddedData, + PrognosticState, +) +from fme.ace.data_loading.getters import get_forcing_data +from fme.ace.data_loading.inference import ( ExplicitIndices, ForcingDataLoaderConfig, InferenceInitialConditionIndices, TimestampList, ) +from fme.ace.inference.data_writer import DataWriter, DataWriterConfig +from fme.ace.inference.loop import write_reduced_metrics +from fme.ace.stepper import SingleModuleStepper, SingleModuleStepperConfig from fme.core.dicts import to_flat_dict +from fme.core.generics.inference import get_record_to_wandb, run_inference from fme.core.logging_utils import LoggingConfig from fme.core.ocean import OceanConfig -from fme.core.stepper import SingleModuleStepperConfig -from fme.core.typing_ import TensorMapping -from fme.core.wandb import WandB +from fme.core.timing import GlobalTimer from .evaluator import load_stepper, load_stepper_config, validate_time_coarsen_config @@ -44,12 +45,12 @@ class InitialConditionConfig: .. note:: The data specified under path should contain a time dimension of at least - length 1. If multiple times are present in the dataset specified by `path`, + length 1. If multiple times are present in the dataset specified by ``path``, the inference will start an ensemble simulation using each IC along a leading sample dimension. Specific times can be selected from the dataset - by using `start_indices`. + by using ``start_indices``. - Attributes: + Parameters: path: The path to the initial conditions dataset. engine: The engine used to open the dataset. 
start_indices: optional specification of the subset of @@ -79,7 +80,7 @@ def _subselect_initial_conditions(self, ds: xr.Dataset) -> xr.Dataset: def get_initial_condition( ds: xr.Dataset, prognostic_names: Sequence[str] -) -> Tuple[TensorMapping, xr.DataArray]: +) -> PrognosticState: """Given a dataset, extract a mapping of variables to tensors. and the time coordinate corresponding to the initial conditions. @@ -90,7 +91,7 @@ def get_initial_condition( prognostic_names: Names of prognostic variables to extract from the dataset. Returns: - A mapping of variable names to tensors and the time coordinate. + The initial condition and the time coordinate. """ initial_condition = {} for name in prognostic_names: @@ -100,16 +101,26 @@ def get_initial_condition( f"(n_samples, n_lat, n_lon). Got shape {ds[name].shape}." ) n_samples = ds[name].shape[0] - initial_condition[name] = torch.tensor(ds[name].values).to(fme.get_device()) + initial_condition[name] = torch.tensor(ds[name].values).unsqueeze(dim=1) if "time" not in ds: raise ValueError("Initial condition dataset must have a 'time' variable.") - initial_times = ds.time - if len(initial_times) != n_samples: + initial_times = xr.DataArray( + data=ds.time.values[:, None], + dims=["sample", "time"], + ) + if initial_times.shape[0] != n_samples: raise ValueError( "Length of 'time' variable must match first dimension of variables " - f"in initial condition dataset. Got {len(initial_times)} and {n_samples}." + f"in initial condition dataset. Got {initial_times.shape[0]} " + f"and {n_samples}." ) - return initial_condition, initial_times + + batch_data = BatchData.new_on_cpu( + data=initial_condition, + time=initial_times, + horizontal_dims=["lat", "lon"], + ) + return batch_data.get_start(prognostic_names, n_ic_timesteps=1) @dataclasses.dataclass @@ -117,7 +128,7 @@ class InferenceConfig: """ Configuration for running inference. - Attributes: + Parameters: experiment_dir: Directory to save results to. n_forward_steps: Number of steps to run the model forward for. checkpoint_path: Path to stepper checkpoint to load. @@ -167,46 +178,28 @@ def configure_wandb(self, env_vars: Optional[dict] = None, **kwargs): def clean_wandb(self): self.logging.clean_wandb(self.experiment_dir) - def configure_gcs(self): - self.logging.configure_gcs() - - def load_stepper( - self, area: torch.Tensor, sigma_coordinates: SigmaCoordinates - ) -> SingleModuleStepper: - """ - Args: - area: A tensor of shape (n_lat, n_lon) containing the area of - each grid cell. - sigma_coordinates: The sigma coordinates of the model. 
- """ + def load_stepper(self) -> SingleModuleStepper: logging.info(f"Loading trained model checkpoint from {self.checkpoint_path}") - stepper = load_stepper( - self.checkpoint_path, - area=area, - sigma_coordinates=sigma_coordinates, - ocean_config=self.ocean, - ) + stepper = load_stepper(self.checkpoint_path, ocean_config=self.ocean) return stepper def load_stepper_config(self) -> SingleModuleStepperConfig: logging.info(f"Loading trained model checkpoint from {self.checkpoint_path}") - return load_stepper_config(self.checkpoint_path) + return load_stepper_config(self.checkpoint_path, ocean_config=self.ocean) - def get_data_writer( - self, data: GriddedData, prognostic_names: Sequence[str] - ) -> DataWriter: + def get_data_writer(self, data: InferenceGriddedData) -> DataWriter: return self.data_writer.build( experiment_dir=self.experiment_dir, - n_samples=data.loader.dataset.n_samples, + # each batch contains all samples, for different times + n_initial_conditions=data.n_initial_conditions, n_timesteps=self.n_forward_steps, timestep=data.timestep, - prognostic_names=prognostic_names, - metadata=data.metadata, + variable_metadata=data.variable_metadata, coords=data.coords, ) -def main(yaml_config: str): +def main(yaml_config: str, segments: Optional[int] = None): with open(yaml_config, "r") as f: data = yaml.safe_load(f) config = dacite.from_dict( @@ -218,10 +211,19 @@ def main(yaml_config: str): os.makedirs(config.experiment_dir, exist_ok=True) with open(os.path.join(config.experiment_dir, "config.yaml"), "w") as f: yaml.dump(data, f, default_flow_style=False, sort_keys=False) - run_inference_from_config(config) + if segments is None: + with GlobalTimer(): + return run_inference_from_config(config) + else: + config.configure_logging(log_filename="inference_out.log") + run_segmented_inference(config, segments) def run_inference_from_config(config: InferenceConfig): + timer = GlobalTimer.get_instance() + timer.start_outer("inference") + timer.start("initialization") + if not os.path.isdir(config.experiment_dir): os.makedirs(config.experiment_dir, exist_ok=True) config.configure_logging(log_filename="inference_out.log") @@ -234,86 +236,109 @@ def run_inference_from_config(config: InferenceConfig): logging_utils.log_versions() logging.info(f"Current device is {fme.get_device()}") - start_time = time.time() stepper_config = config.load_stepper_config() - data_requirements = stepper_config.get_forcing_data_requirements( - n_forward_steps=config.n_forward_steps + data_requirements = stepper_config.get_forcing_window_data_requirements( + n_forward_steps=config.forward_steps_in_memory ) logging.info("Loading initial condition data") - initial_condition, initial_times = get_initial_condition( + initial_condition = get_initial_condition( config.initial_condition.get_dataset(), stepper_config.prognostic_names ) + stepper = config.load_stepper() logging.info("Initializing forcing data loaded") data = get_forcing_data( - config.forcing_loader, - config.forward_steps_in_memory, - data_requirements, - initial_times, - ) - - stepper = config.load_stepper( - data.area_weights.to(fme.get_device()), - sigma_coordinates=data.sigma_coordinates.to(fme.get_device()), + config=config.forcing_loader, + total_forward_steps=config.n_forward_steps, + window_requirements=data_requirements, + initial_condition=initial_condition, + surface_temperature_name=stepper.surface_temperature_name, + ocean_fraction_name=stepper.ocean_fraction_name, ) if stepper.timestep != data.timestep: raise ValueError( f"Timestep of the 
loaded stepper, {stepper.timestep}, does not " f"match that of the forcing data, {data.timestep}." ) - aggregator = config.aggregator.build( - area_weights=data.area_weights.to(fme.get_device()), - sigma_coordinates=data.sigma_coordinates, - timestep=data.timestep, - n_timesteps=config.n_forward_steps + 1, - metadata=data.metadata, + gridded_operations=data.gridded_operations, + n_timesteps=config.n_forward_steps + stepper.n_ic_timesteps, + variable_metadata=data.variable_metadata, ) - writer = config.get_data_writer(data, stepper.prognostic_names) + writer = config.get_data_writer(data) + timer.stop() logging.info("Starting inference") - timers = run_inference( - stepper=stepper, - initial_condition=initial_condition, - forcing_data=data, + record_logs = get_record_to_wandb(label="inference") + run_inference( + predict=stepper.predict, + data=data, writer=writer, aggregator=aggregator, + record_logs=record_logs, ) - final_flush_start_time = time.time() + timer.start("final_writer_flush") logging.info("Starting final flush of data writer") writer.flush() - for name, ds in aggregator.get_datasets(("time_mean",)).items(): - coords = {k: v for k, v in data.coords.items() if k in ds.dims} - ds = ds.assign_coords(coords) - ds.to_netcdf(Path(config.experiment_dir) / f"{name}_diagnostics.nc") - final_flush_duration = time.time() - final_flush_start_time - logging.info(f"Final writer flush duration: {final_flush_duration:.2f} seconds") - timers["final_writer_flush"] = final_flush_duration - - duration = time.time() - start_time - total_steps = config.n_forward_steps * data.loader.dataset.n_samples - total_steps_per_second = total_steps / duration - logging.info(f"Inference duration: {duration:.2f} seconds") - logging.info(f"Total steps per second: {total_steps_per_second:.2f} steps/second") - - step_logs = aggregator.get_inference_logs(label="inference") - wandb = WandB.get_instance() - if wandb.enabled: - logging.info("Starting logging of metrics to wandb") - duration_logs = { - "duration_seconds": duration, - "total_steps_per_second": total_steps_per_second, - } - wandb.log({**timers, **duration_logs}, step=0) - for i, log in enumerate(step_logs): - wandb.log(log, step=i, sleep=0.01) + logging.info("Writing reduced metrics to disk in netcdf format.") + write_reduced_metrics(aggregator, data.coords, config.experiment_dir) + timer.stop() + + timer.stop_outer("inference") + total_steps = config.n_forward_steps * data.n_initial_conditions + inference_duration = timer.get_duration("inference") + wandb_logging_duration = timer.get_duration("wandb_logging") + total_steps_per_second = total_steps / (inference_duration - wandb_logging_duration) + timer.log_durations() + logging.info( + "Total steps per second (ignoring wandb logging): " + f"{total_steps_per_second:.2f} steps/second" + ) + summary_logs = { + "total_steps_per_second": total_steps_per_second, + **timer.get_durations(), + **aggregator.get_summary_logs(), + } + record_logs([summary_logs]) config.clean_wandb() -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("yaml_config", type=str) - args = parser.parse_args() - main(yaml_config=args.yaml_config) +def run_segmented_inference(config: InferenceConfig, segments: int): + """Run inference in multiple segments. + + Args: + config: inference configuration to be used for each individual segment. The + provided initial condition configuration will only be used for the first + segment. + segments: total number of segments desired. Only missing segments will be run. 
+ + Note: + This is useful when running very long simulations or when saving a large + amount of output data to disk. The simulation outputs will be split across + multiple folders, each corresponding to one of the segments and labeled by + the segment number. + """ + logging.info( + f"Starting segmented inference with {segments} segments. " + f"Saving to {config.experiment_dir}." + ) + config_copy = copy.deepcopy(config) + original_wandb_name = os.environ.get("WANDB_NAME") + for segment in range(segments): + segment_label = f"segment_{segment:04d}" + segment_dir = os.path.join(config.experiment_dir, segment_label) + restart_path = os.path.join(segment_dir, "restart.nc") + if os.path.exists(restart_path): + logging.info(f"Skipping segment {segment} because it has already been run.") + else: + logging.info(f"Running segment {segment}.") + config_copy.experiment_dir = segment_dir + if original_wandb_name is not None: + os.environ["WANDB_NAME"] = f"{original_wandb_name}-{segment_label}" + with GlobalTimer(): + run_inference_from_config(config_copy) + config_copy.initial_condition = InitialConditionConfig( + path=restart_path, engine="netcdf4" + ) diff --git a/fme/fme/ace/inference/loop.py b/fme/fme/ace/inference/loop.py index 00f8118..87e4c33 100644 --- a/fme/fme/ace/inference/loop.py +++ b/fme/fme/ace/inference/loop.py @@ -1,376 +1,113 @@ +import abc import logging -import time -from collections import defaultdict -from typing import Any, Dict, Mapping, Optional, Union +from pathlib import Path +from typing import Callable, Iterable, Mapping, Optional, Union -import torch -import xarray as xr +import numpy as np -from fme.core import SingleModuleStepper -from fme.core.aggregator.inference.main import ( +from fme.ace.aggregator.inference.main import ( InferenceAggregator, InferenceEvaluatorAggregator, ) -from fme.core.data_loading.data_typing import GriddedData -from fme.core.data_loading.utils import BatchData -from fme.core.device import get_device -from fme.core.normalizer import StandardNormalizer -from fme.core.optimization import NullOptimization -from fme.core.stepper import SteppedData -from fme.core.typing_ import TensorMapping - -from .data_writer import DataWriter, NullDataWriter, PairedDataWriter -from .derived_variables import ( - compute_derived_quantities, - compute_stepped_derived_quantities, -) +from fme.ace.data_loading.batch_data import BatchData, PairedData, PrognosticState +from fme.ace.inference.data_writer import PairedDataWriter +from fme.core.generics.aggregator import InferenceAggregatorABC, InferenceLogs +from fme.core.generics.data import InferenceDataABC +from fme.core.generics.inference import get_record_to_wandb +from fme.core.generics.writer import NullDataWriter +from fme.core.timing import GlobalTimer -class WindowStitcher: +class DeriverABC(abc.ABC): """ - Handles stitching together the windows of data from the inference loop. - - For example, handles passing in windows to data writers which combine - them together into a continuous series, and handles storing prognostic - variables from the end of a window to use as the initial condition for - the next window. + Abstract base class for processing data during dataset comparison. 
""" - def __init__( - self, - n_forward_steps: int, - writer: Union[PairedDataWriter, NullDataWriter], - ): - self.i_time = 0 - self.n_forward_steps = n_forward_steps - self.writer = writer - # tensors have shape [n_sample, n_lat, n_lon] with no time axis - self._initial_condition: Optional[Mapping[str, torch.Tensor]] = None - - def append( - self, - data: Dict[str, torch.tensor], - gen_data: Dict[str, torch.tensor], - batch_times: xr.DataArray, - ) -> None: - """ - Appends a time segment of data to the ensemble batch. - - Args: - data: The reference data for the current time segment, tensors - should have shape [n_sample, n_time, n_lat, n_lon] - gen_data: The generated data for the current time segment, tensors - should have shape [n_sample, n_time, n_lat, n_lon] - batch_times: Time coordinates for each sample in the batch. - """ - tensor_shape = next(data.values().__iter__()).shape - self.writer.append_batch( - target=data, - prediction=gen_data, - start_timestep=self.i_time, - batch_times=batch_times, - ) - self.i_time += tensor_shape[1] - self._initial_condition = {key: value[:, -1] for key, value in data.items()} - for key, value in gen_data.items(): - self._initial_condition[key] = value[:, -1] - for key, value in self._initial_condition.items(): - self._initial_condition[key] = value.detach().cpu() - - def apply_initial_condition(self, data: Mapping[str, torch.Tensor]): - """ - Applies the last recorded state of the batch as the initial condition for - the next segment of the timeseries. + @abc.abstractmethod + def get_forward_data( + self, data: BatchData, compute_derived_variables: bool = False + ) -> BatchData: ... - Args: - data: The data to apply the initial condition to, tensors should have - shape [n_sample, n_time, n_lat, n_lon] and the first value along - the time axis will be replaced with the last value from the - previous segment. - """ - if self.i_time > self.n_forward_steps: - raise ValueError( - "Cannot apply initial condition after " - "the last segment has been appended, currently at " - f"time index {self.i_time} " - f"with {self.n_forward_steps} max forward steps." - ) - if self._initial_condition is not None: - for key, value in data.items(): - value[:, 0] = self._initial_condition[key].to(value.device) - - def save_initial_condition( - self, - ic_data: Dict[str, torch.Tensor], - ic_time: xr.DataArray, - ): - self.writer.save_initial_condition(ic_data, ic_time) + @property + @abc.abstractmethod + def n_ic_timesteps(self) -> int: ... -def _inference_internal_loop( - stepped: SteppedData, - i_time: int, - aggregator: InferenceEvaluatorAggregator, - stitcher: WindowStitcher, - batch_times: xr.DataArray, +def write_reduced_metrics( + aggregator: Union[InferenceEvaluatorAggregator, InferenceAggregator], + data_coords: Mapping[str, np.ndarray], + path: str, + excluded: Optional[Iterable[str]] = None, ): - """Do operations that need to be done on each time step of the inference loop. - - This function exists to de-duplicate code between run_inference_evaluator and - run_dataset_comparison.""" - - # The first data window includes the IC, while subsequent windows don't. - # The aggregators use the full first window including IC. - # The data writers exclude the IC from the first window. 
- if i_time == 0: - i_time_aggregator = i_time - stepped_no_ic = stepped.remove_initial_condition() - stitcher.save_initial_condition( - ic_data={k: v[:, 0] for k, v in stepped.target_data.items()}, - ic_time=batch_times.isel(time=0), - ) - batch_times_no_ic = batch_times.isel(time=slice(1, None)) - else: - i_time_aggregator = i_time + 1 - stepped_no_ic = stepped - batch_times_no_ic = batch_times - - # record raw data for the batch, and store the final state - # for the next segment - # Do not include the initial condition in the data writers - stitcher.append( - stepped_no_ic.target_data, stepped_no_ic.gen_data, batch_times_no_ic - ) - - # record metrics, includes the initial condition - aggregator.record_batch( - loss=float(stepped.metrics["loss"]), - time=batch_times, - target_data=stepped.target_data, - gen_data=stepped.gen_data, - target_data_norm=stepped.target_data_norm, - gen_data_norm=stepped.gen_data_norm, - i_time_start=i_time_aggregator, - ) - - -def _to_device( - data: Mapping[str, torch.Tensor], device: torch.device -) -> Dict[str, Any]: - return {key: value.to(device) for key, value in data.items()} - - -def run_inference( - stepper: SingleModuleStepper, - initial_condition: TensorMapping, - forcing_data: GriddedData, - writer: DataWriter, - aggregator: InferenceAggregator, -) -> Dict[str, float]: - """Run extended inference loop given initial condition and forcing data. + """ + Write the reduced metrics to disk. Each sub-aggregator will write a netCDF file + if its `get_dataset` method returns a non-empty dataset. Args: - stepper: The model to run inference with. - initial_condition: Mapping of prognostic names to initial condition tensors of - shape (n_sample, n_lat, n_lon). - forcing_data: GriddedData object which includes a DataLoader which will provide - windows of forcing data appropriately aligned with the initial condition. - writer: Data writer for saving the inference results to disk. - - Returns: - Execution time in seconds for each step of the inference loop. + aggregator: The aggregator to write metrics from. + data_coords: Coordinates to assign to the datasets. + path: Path to write the metrics to. + excluded: Names of metrics to exclude from writing. """ - with torch.no_grad(): - timers: Dict[str, float] = defaultdict(float) - current_time = time.time() - i_time = 0 - window_forcing: BatchData - for window_forcing in forcing_data.loader: - timers["data_loading"] += time.time() - current_time - current_time = time.time() - forward_steps_in_memory = list(window_forcing.data.values())[0].size(1) - 1 - logging.info( - f"Inference: starting window spanning {i_time}" - f" to {i_time + forward_steps_in_memory} steps." - ) - window_forcing_data = _to_device(window_forcing.data, get_device()) - prediction = stepper.predict( - initial_condition, window_forcing_data, forward_steps_in_memory - ) - timers["run_on_batch"] += time.time() - current_time - - # Replicates the timestep range of the stepped.target_data - # used in the evaluator computation of derived quantities, which drops - # the initial condition. 
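(Annotation, not part of the diff.) The new `DeriverABC` interface introduced at the top of this loop.py rewrite only requires two members: a `n_ic_timesteps` property and a `get_forward_data` method. A minimal concrete implementation is sketched below for illustration only; `IdentityDeriver` is a made-up name, and the real deriver used by the evaluator also trims initial-condition steps and computes derived variables.

```python
from fme.ace.data_loading.batch_data import BatchData
from fme.ace.inference.loop import DeriverABC


class IdentityDeriver(DeriverABC):
    """Illustrative deriver that passes batches through unchanged."""

    @property
    def n_ic_timesteps(self) -> int:
        # one initial-condition timestep, matching the tests in this diff
        return 1

    def get_forward_data(
        self, data: BatchData, compute_derived_variables: bool = False
    ) -> BatchData:
        # a real implementation would drop IC steps and/or derive variables here
        return data
```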
- forcing_data_at_prediction_steps = { - k: window_forcing_data[k][:, 1:] for k in window_forcing_data - } - prediction = compute_derived_quantities( - prediction, - forcing_data.sigma_coordinates, - forcing_data.timestep, - forcing_data=forcing_data_at_prediction_steps, - ) - - forward_times = window_forcing.times.isel(time=slice(1, None)) - writer.append_batch(prediction, i_time, forward_times) - aggregator.record_batch( - time=forward_times, data=prediction, i_time_start=i_time + 1 - ) - timers["writer_and_aggregator"] += time.time() - current_time - current_time = time.time() - initial_condition = { - k: prediction[k][:, -1] for k in stepper.prognostic_names - } - i_time += forward_steps_in_memory - - for name, duration in timers.items(): - logging.info(f"{name} duration: {duration:.2f}s") - return timers - - -def run_inference_evaluator( - aggregator: InferenceEvaluatorAggregator, - stepper: SingleModuleStepper, - data: GriddedData, - writer: Optional[Union[PairedDataWriter, NullDataWriter]] = None, -) -> Dict[str, float]: - if writer is None: - writer = NullDataWriter() - n_forward_steps = data.loader.dataset.n_forward_steps - stitcher = WindowStitcher(n_forward_steps, writer) - - with torch.no_grad(): - # We have data batches with long windows, where all data for a - # given batch does not fit into memory at once, so we window it in time - # and run the model on each window in turn. - # - # We process each time window and keep track of the - # final state. We then use this as the initial condition - # for the next time window. - - timers: Dict[str, float] = defaultdict(float) - current_time = time.time() - i_time = 0 - for i, window_batch_data in enumerate(data.loader): - timers["data_loading"] += time.time() - current_time - current_time = time.time() - forward_steps_in_memory = ( - list(window_batch_data.data.values())[0].size(1) - 1 - ) - logging.info( - f"Inference: starting window spanning {i_time}" - f" to {i_time + forward_steps_in_memory} steps, " - f"out of total {n_forward_steps}." 
- ) - device = get_device() - window_data = _to_device(window_batch_data.data, device) - - stitcher.apply_initial_condition(window_data) - - stepped = stepper.run_on_batch( - window_data, - NullOptimization(), - n_forward_steps=forward_steps_in_memory, - ) - - # Prepend initial (pre-first-timestep) output for the first window - if i == 0: - ( - initial_condition, - normed_initial_condition, - ) = stepper.get_initial_condition(window_data) - stepped = stepped.prepend_initial_condition( - initial_condition, normed_initial_condition - ) - batch_times = window_batch_data.times - else: - batch_times = window_batch_data.times.isel(time=slice(1, None)) - stepped = compute_stepped_derived_quantities( - stepped, - data.sigma_coordinates, - data.timestep, - # forcing inputs are in target data but not gen_data - forcing_data=stepped.target_data, - ) - timers["run_on_batch"] += time.time() - current_time - current_time = time.time() - _inference_internal_loop(stepped, i_time, aggregator, stitcher, batch_times) - timers["writer_and_aggregator"] += time.time() - current_time - current_time = time.time() - i_time += forward_steps_in_memory - - for name, duration in timers.items(): - logging.info(f"{name} duration: {duration:.2f}s") - return timers + for name, ds in aggregator.get_datasets(excluded_aggregators=excluded).items(): + if len(ds) > 0: + coords = {k: v for k, v in data_coords.items() if k in ds.dims} + ds = ds.assign_coords(coords) + ds.to_netcdf(Path(path) / f"{name}_diagnostics.nc") def run_dataset_comparison( - aggregator: InferenceEvaluatorAggregator, - normalizer: StandardNormalizer, - prediction_data: GriddedData, - target_data: GriddedData, + aggregator: InferenceAggregatorABC[PairedData, PairedData], + prediction_data: InferenceDataABC[PrognosticState, BatchData], + target_data: InferenceDataABC[PrognosticState, BatchData], + deriver: DeriverABC, writer: Optional[Union[PairedDataWriter, NullDataWriter]] = None, -) -> Dict[str, float]: + record_logs: Optional[Callable[[InferenceLogs], None]] = None, +): + if record_logs is None: + record_logs = get_record_to_wandb(label="inference") if writer is None: writer = NullDataWriter() - n_forward_steps = target_data.loader.dataset.n_forward_steps - stitcher = WindowStitcher(n_forward_steps, writer) - device = get_device() - # We have data batches with long windows, where all data for a - # given batch does not fit into memory at once, so we window it in time - # and run the model on each window in turn. - # - # We process each time window and keep track of the - # final state. We then use this as the initial condition - # for the next time window. 
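(Annotation, not part of the diff.) The `write_reduced_metrics` helper added in this rewrite boils down to: for every non-empty sub-aggregator dataset, attach only the coordinates it actually uses and write one `<name>_diagnostics.nc` file. A self-contained sketch of that behaviour with toy datasets follows; the dataset names, variable names, and output path are illustrative stand-ins, not values taken from the codebase.

```python
from pathlib import Path

import numpy as np
import xarray as xr

# toy stand-ins for two sub-aggregator outputs: one gridded, one empty
datasets = {
    "time_mean": xr.Dataset({"bias_map-var": (("lat", "lon"), np.zeros((4, 8)))}),
    "annual": xr.Dataset(),  # empty -> skipped, no file written
}
data_coords = {"lat": np.linspace(-60, 60, 4), "lon": np.linspace(0, 315, 8)}

out_dir = Path("/tmp/inference_metrics")
out_dir.mkdir(parents=True, exist_ok=True)
for name, ds in datasets.items():
    if len(ds) == 0:
        continue
    # keep only coordinates that correspond to dimensions of this dataset
    coords = {k: v for k, v in data_coords.items() if k in ds.dims}
    ds.assign_coords(coords).to_netcdf(out_dir / f"{name}_diagnostics.nc")
```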
- timers: Dict[str, float] = defaultdict(float) - current_time = time.time() + timer = GlobalTimer.get_instance() + timer.start("data_loading") i_time = 0 + n_windows = min(len(prediction_data.loader), len(target_data.loader)) for i, (pred, target) in enumerate(zip(prediction_data.loader, target_data.loader)): - timers["data_loading"] += time.time() - current_time - current_time = time.time() + timer.stop() + if i_time == 0: + with timer.context("aggregator"): + logs = aggregator.record_initial_condition( + initial_condition=PairedData.from_batch_data( + prediction=prediction_data.initial_condition.as_batch_data(), + target=target_data.initial_condition.as_batch_data(), + ), + ) + with timer.context("wandb_logging"): + record_logs(logs) + forward_steps_in_memory = list(pred.data.values())[0].size(1) - 1 logging.info( - f"Inference: starting window spanning {i_time}" - f" to {i_time + forward_steps_in_memory} steps," - f" out of total {n_forward_steps}." - ) - pred_window_data = _to_device(pred.data, device) - target_window_data = _to_device(target.data, device) - stepped = SteppedData( - {"loss": torch.tensor(float("nan"))}, - pred_window_data, - target_window_data, - normalizer.normalize(pred_window_data), - normalizer.normalize(target_window_data), - ) - stepped = compute_stepped_derived_quantities( - stepped, target_data.sigma_coordinates, target_data.timestep + f"Inference: Processing window {i + 1} of {n_windows}" + f" spanning {i_time} to {i_time + forward_steps_in_memory} steps." ) + pred = deriver.get_forward_data(pred, compute_derived_variables=True) + target = deriver.get_forward_data(target, compute_derived_variables=True) + paired_data = PairedData.from_batch_data(prediction=pred, target=target) + + with timer.context("data_writer"): + writer.append_batch( + batch=paired_data, + ) + with timer.context("aggregator"): + logs = aggregator.record_batch( + data=paired_data, + ) - # Windows here all include an initial condition at start. - # Remove IC and time coord for windows >0 to be consistent with - # run_on_batch outputs before passing to the shared _inference_internal_loop. 
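(Annotation, not part of the diff.) The rewritten comparison loop replaces the old ad-hoc `time.time()` bookkeeping with named `GlobalTimer` sections (`data_loading`, `aggregator`, `data_writer`, `wandb_logging`). Going only by the calls visible in this diff (`get_instance`, `start`, `stop`, `context`), the pattern looks roughly like the sketch below; it assumes the caller has already entered `with GlobalTimer():`, as the inference entrypoints do, and `process_window` is a stand-in for the writer/aggregator/logging work.

```python
from fme.core.timing import GlobalTimer


def timed_window_loop(loader, process_window):
    timer = GlobalTimer.get_instance()
    timer.start("data_loading")
    for window in loader:
        timer.stop()  # the window has been fetched; close the data_loading section
        with timer.context("aggregator"):
            process_window(window)
        timer.start("data_loading")  # time the fetch of the next window
    timer.stop()
```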
- if i > 0: - stepped = stepped.remove_initial_condition() - target_times = target.times.isel(time=slice(1, None)) - else: - target_times = target.times + with timer.context("wandb_logging"): + record_logs(logs) - timers["run_on_batch"] += time.time() - current_time - current_time = time.time() - _inference_internal_loop( - stepped, - i_time, - aggregator, - stitcher, - target_times, - ) - timers["writer_and_aggregator"] += time.time() - current_time - current_time = time.time() + timer.start("data_loading") i_time += forward_steps_in_memory - for name, duration in timers.items(): - logging.info(f"{name} duration: {duration:.2f}s") - return timers + + timer.stop() diff --git a/fme/fme/ace/inference/stepper_test_data b/fme/fme/ace/inference/stepper_test_data index b9fc36c..4c3a594 100644 Binary files a/fme/fme/ace/inference/stepper_test_data and b/fme/fme/ace/inference/stepper_test_data differ diff --git a/fme/fme/ace/inference/test_derived_variables.py b/fme/fme/ace/inference/test_derived_variables.py index f06b3b0..4c93012 100644 --- a/fme/fme/ace/inference/test_derived_variables.py +++ b/fme/fme/ace/inference/test_derived_variables.py @@ -1,70 +1,114 @@ import datetime +import numpy as np import pytest import torch +import xarray as xr -from fme.core.data_loading.data_typing import SigmaCoordinates -from fme.core.stepper import SteppedData +from fme.ace.stepper import TrainOutput +from fme.core.climate_data import ClimateData +from fme.core.coordinates import HybridSigmaPressureCoordinate +from fme.core.typing_ import TensorDict, TensorMapping -from .derived_variables import ( - DerivedVariableRegistryEntry, - _compute_derived_variable, - compute_stepped_derived_quantities, -) +from .derived_variables import _compute_derived_variable, compute_derived_quantities TIMESTEP = datetime.timedelta(hours=6) def test_compute_derived_variable(): fake_data = {"PRESsfc": torch.tensor([1.0]), "PRATEsfc": torch.tensor([2.0])} - sigma_coordinates = None - derived_variable = DerivedVariableRegistryEntry( - func=lambda data, *_: data.surface_pressure + data.precipitation_rate + vertical_coordinate = HybridSigmaPressureCoordinate( + ak=torch.tensor([0.0, 0.0]), bk=torch.tensor([0.0, 1.0]) ) + + def _derived_variable_func(data: ClimateData, *_) -> torch.Tensor: + return data.surface_pressure + data.precipitation_rate + output_data = _compute_derived_variable( - fake_data, sigma_coordinates, TIMESTEP, "c", derived_variable + fake_data, vertical_coordinate, TIMESTEP, "c", _derived_variable_func ) torch.testing.assert_close(output_data["c"], torch.tensor([3.0])) def test_compute_derived_variable_raises_value_error_when_overwriting(): fake_data = {"PRESsfc": torch.tensor([1.0]), "PRATEsfc": torch.tensor([2.0])} - sigma_coordinates = None - derived_variable = DerivedVariableRegistryEntry( - func=lambda data, _: data.surface_pressure + data.precipitation_rate + vertical_coordinate = HybridSigmaPressureCoordinate( + ak=torch.tensor([0.0, 0.0]), bk=torch.tensor([0.0, 1.0]) ) + + def add_surface_pressure_and_precipitation(data: ClimateData, *_) -> torch.Tensor: + return data.surface_pressure + data.precipitation_rate + + derived_variable_func = add_surface_pressure_and_precipitation with pytest.raises(ValueError): _compute_derived_variable( - fake_data, sigma_coordinates, TIMESTEP, "PRATEsfc", derived_variable + fake_data, vertical_coordinate, TIMESTEP, "PRATEsfc", derived_variable_func ) -def test_compute_derived_quantities(): +@pytest.mark.parametrize("dataset", ["fv3", "e3sm"]) +def 
test_compute_derived_quantities(dataset: str): torch.manual_seed(0) - fake_data = { - "PRESsfc": 10.0 + torch.rand(2, 3, 4, 8), - "specific_total_water_0": torch.rand(2, 3, 4, 8), - "specific_total_water_1": torch.rand(2, 3, 4, 8), - "PRATEsfc": torch.rand(2, 3, 4, 8), - "LHTFLsfc": torch.rand(2, 3, 4, 8), - "tendency_of_total_water_path_due_to_advection": torch.rand(2, 3, 4, 8), - } - data = SteppedData( - metrics={"loss": torch.tensor(0.0)}, - gen_data=fake_data, - target_data=fake_data, - gen_data_norm=fake_data, - target_data_norm=fake_data, - ) - sigma_coordinates = SigmaCoordinates( + + if dataset == "fv3": + fake_data = { + "PRESsfc": 10.0 + torch.rand(2, 3, 4, 8), + "specific_total_water_0": torch.rand(2, 3, 4, 8), + "specific_total_water_1": torch.rand(2, 3, 4, 8), + "PRATEsfc": torch.rand(2, 3, 4, 8), + "LHTFLsfc": torch.rand(2, 3, 4, 8), + "tendency_of_total_water_path_due_to_advection": torch.rand(2, 3, 4, 8), + "DSWRFtoa": torch.rand(2, 3, 4, 8), + "USWRFtoa": torch.rand(2, 3, 4, 8), + "ULWRFtoa": torch.rand(2, 3, 4, 8), + } + gen_data = fake_data.copy() + del gen_data["DSWRFtoa"] + + if dataset == "e3sm": + fake_data = { + "PS": 10.0 + torch.rand(2, 3, 4, 8), + "specific_total_water_0": torch.rand(2, 3, 4, 8), + "specific_total_water_1": torch.rand(2, 3, 4, 8), + "surface_precipitation_rate": torch.rand(2, 3, 4, 8), + "LHFLX": torch.rand(2, 3, 4, 8), + "tendency_of_total_water_path_due_to_advection": torch.rand(2, 3, 4, 8), + "SOLIN": torch.rand(2, 3, 4, 8), + "top_of_atmos_upward_shortwave_flux": torch.rand(2, 3, 4, 8), + "FLUT": torch.rand(2, 3, 4, 8), + } + gen_data = fake_data.copy() + del gen_data["SOLIN"] + + vertical_coordinate = HybridSigmaPressureCoordinate( ak=torch.tensor([0.0, 0.5, 0.0]), bk=torch.tensor([0.0, 0.5, 1.0]), ) - out_data = compute_stepped_derived_quantities(data, sigma_coordinates, TIMESTEP) + + def derive_func(data: TensorMapping, forcing_data: TensorMapping) -> TensorDict: + updated = compute_derived_quantities( + dict(data), + vertical_coordinate=vertical_coordinate, + timestep=TIMESTEP, + forcing_data=dict(forcing_data), + ) + return updated + + data = TrainOutput( + metrics={"loss": torch.tensor(0.0)}, + gen_data=gen_data, + target_data=fake_data, + time=xr.DataArray(np.zeros((2, 3)), dims=["sample", "time"]), + normalize=lambda x: x, + derive_func=derive_func, + ) + out_data = data.compute_derived_variables() for name in ( "total_water_path_budget_residual", "total_water_path", "surface_pressure_due_to_dry_air", + "surface_pressure_due_to_dry_air_absolute_tendency", + "net_energy_flux_toa_into_atmosphere", ): assert name in out_data.gen_data assert name in out_data.target_data diff --git a/fme/fme/ace/inference/test_evaluator.py b/fme/fme/ace/inference/test_evaluator.py index a2227ae..99cad2a 100644 --- a/fme/fme/ace/inference/test_evaluator.py +++ b/fme/fme/ace/inference/test_evaluator.py @@ -1,8 +1,8 @@ -import contextlib import dataclasses import datetime +import os import pathlib -from typing import List, Tuple +from typing import List import dacite import numpy as np @@ -11,25 +11,29 @@ import xarray as xr import yaml +from fme.ace.aggregator.inference import InferenceEvaluatorAggregatorConfig +from fme.ace.data_loading.inference import ( + InferenceDataLoaderConfig, + InferenceInitialConditionIndices, +) from fme.ace.inference.data_writer import DataWriterConfig from fme.ace.inference.data_writer.time_coarsen import TimeCoarsenConfig -from fme.ace.inference.derived_variables import compute_stepped_derived_quantities from 
fme.ace.inference.evaluator import InferenceEvaluatorConfig, main from fme.ace.registry import ModuleSelector +from fme.ace.stepper import SingleModuleStepperConfig, TrainOutput +from fme.ace.testing import DimSizes, FV3GFSData, MonthlyReferenceData from fme.core import metrics -from fme.core.aggregator.inference import InferenceEvaluatorAggregatorConfig, annual -from fme.core.data_loading.config import XarrayDataConfig -from fme.core.data_loading.data_typing import SigmaCoordinates -from fme.core.data_loading.inference import ( - InferenceDataLoaderConfig, - InferenceInitialConditionIndices, -) +from fme.core.coordinates import DimSize, HybridSigmaPressureCoordinate +from fme.core.dataset.config import XarrayDataConfig from fme.core.device import get_device +from fme.core.gridded_ops import LatLonOperations from fme.core.logging_utils import LoggingConfig -from fme.core.normalizer import FromStateNormalizer +from fme.core.normalizer import NormalizationConfig from fme.core.ocean import Ocean, OceanConfig -from fme.core.stepper import SingleModuleStepperConfig, SteppedData -from fme.core.testing import DimSizes, FV3GFSData, MonthlyReferenceData, mock_wandb +from fme.core.testing import mock_wandb +from fme.core.typing_ import TensorDict, TensorMapping + +from .derived_variables import compute_derived_quantities DIR = pathlib.Path(__file__).parent TIMESTEP = datetime.timedelta(hours=6) @@ -40,41 +44,34 @@ def forward(self, x): return x + 1 -@contextlib.contextmanager -def patch_annual_aggregator_min_samples(value): - original = annual.MIN_SAMPLES - try: - annual.MIN_SAMPLES = value - yield - finally: - annual.MIN_SAMPLES = original - - def save_plus_one_stepper( path: pathlib.Path, - names: List[str], + in_names: List[str], + out_names: List[str], mean: float, std: float, - data_shape: Tuple[int, int, int], + data_shape: List[int], timestep: datetime.timedelta = TIMESTEP, + nz_interface: int = 7, ): + all_names = list(set(in_names).union(out_names)) config = SingleModuleStepperConfig( builder=ModuleSelector(type="prebuilt", config={"module": PlusOne()}), - in_names=["var"], - out_names=["var"], - normalization=FromStateNormalizer( - state={ - "means": {name: mean for name in names}, - "stds": {name: std for name in names}, - } + in_names=in_names, + out_names=out_names, + normalization=NormalizationConfig( + means={name: mean for name in all_names}, + stds={name: std for name in all_names}, ), ) area = torch.ones(data_shape[-2:], device=get_device()) - sigma_coordinates = SigmaCoordinates(ak=torch.arange(7), bk=torch.arange(7)) + vertical_coordinate = HybridSigmaPressureCoordinate( + ak=torch.arange(nz_interface), bk=torch.arange(nz_interface) + ) stepper = config.get_stepper( - img_shape=data_shape[-2:], - area=area, - sigma_coordinates=sigma_coordinates, + img_shape=(data_shape[-2], data_shape[-1]), + gridded_operations=LatLonOperations(area), + vertical_coordinate=vertical_coordinate, timestep=timestep, ) torch.save({"stepper": stepper.get_state()}, path) @@ -87,12 +84,12 @@ def test_inference_backwards_compatibility(tmp_path: pathlib.Path): """ in_names = ["var"] out_names = ["var"] - all_names = list(set(in_names).union(out_names)) stepper_path = DIR / "stepper_test_data" + + horizontal = [DimSize("grid_yt", 4), DimSize("grid_xt", 8)] dim_sizes = DimSizes( n_time=8, - n_lat=4, - n_lon=8, + horizontal=horizontal, nz_interface=2, ) std = 1.0 @@ -101,17 +98,19 @@ def test_inference_backwards_compatibility(tmp_path: pathlib.Path): # to re-generate, just delete the data and run the test (it 
will fail) save_plus_one_stepper( stepper_path, - names=all_names, + in_names, + out_names, mean=0.0, std=std, - data_shape=dim_sizes.shape_2d, + data_shape=dim_sizes.shape_nd, ) assert False, "stepper_test_data did not exist, it has been created" use_prediction_data = False n_forward_steps = 2 inference_helper( tmp_path, - all_names, + in_names, + out_names, use_prediction_data, dim_sizes, n_forward_steps, @@ -129,12 +128,12 @@ def test_inference_plus_one_model( ): in_names = ["var"] out_names = ["var"] - all_names = list(set(in_names).union(out_names)) stepper_path = tmp_path / "stepper" + + horizontal = [DimSize("grid_yt", 16), DimSize("grid_xt", 32)] dim_sizes = DimSizes( n_time=n_forward_steps + 1, - n_lat=16, - n_lon=32, + horizontal=horizontal, nz_interface=4, ) if use_prediction_data: @@ -144,15 +143,17 @@ def test_inference_plus_one_model( std = 1.0 save_plus_one_stepper( stepper_path, - names=all_names, + in_names, + out_names, mean=0.0, std=std, - data_shape=dim_sizes.shape_2d, + data_shape=dim_sizes.shape_nd, timestep=datetime.timedelta(days=20), ) inference_helper( tmp_path, - all_names, + in_names, + out_names, use_prediction_data, dim_sizes, n_forward_steps, @@ -164,21 +165,26 @@ def test_inference_plus_one_model( def inference_helper( tmp_path, - all_names, + in_names, + out_names, use_prediction_data, dim_sizes: DimSizes, n_forward_steps, stepper_path, timestep: datetime.timedelta, save_monthly_files: bool = True, + derived_names: List[str] = [], ): time_varying_values = [float(i) for i in range(dim_sizes.n_time)] + all_names = list(set(in_names).union(out_names)) + forcing_names = list(set(in_names).difference(out_names)) data = FV3GFSData( path=tmp_path, names=all_names, dim_sizes=dim_sizes, time_varying_values=time_varying_values, timestep_days=timestep.total_seconds() / 86400, + save_vertical_coordinate=False, ) if use_prediction_data: prediction_data = data.inference_data_loader_config @@ -192,8 +198,7 @@ def inference_helper( names=all_names, dim_sizes=DimSizes( n_time=48, - n_lat=dim_sizes.n_lat, - n_lon=dim_sizes.n_lon, + horizontal=dim_sizes.horizontal, nz_interface=1, ), n_ensemble=3, @@ -228,18 +233,39 @@ def inference_helper( yaml.dump(dataclasses.asdict(config), f) with mock_wandb() as wandb: - inference_logs = main( + wandb.configure(log_to_wandb=True) + main( yaml_config=str(config_filename), ) + wandb_logs = wandb.get_logs() + + all_out_names = out_names + derived_names + + n_ic_timesteps = 1 + summary_log_step = 1 + assert len(wandb_logs) == n_ic_timesteps + config.n_forward_steps + summary_log_step + for i in range(n_ic_timesteps + config.n_forward_steps): + log = wandb_logs[i] + for var in all_out_names: + if i == 0 and var not in in_names: + assert f"inference/mean/weighted_rmse/{var}" not in log + else: + # if these are off by something like 90% then probably the stepper + # is being used instead of the prediction_data + assert log[f"inference/mean/weighted_rmse/{var}"] == 0.0 + assert log[f"inference/mean/weighted_bias/{var}"] == 0.0 + + var = list(set(in_names).difference(forcing_names))[0] + + if not use_prediction_data: + initial_condition_ds = xr.open_dataset( + tmp_path / "initial_condition.nc", decode_timedelta=False + ) + for dim_name in ["lat", "lon"]: + assert dim_name in initial_condition_ds.dims + assert dim_name in initial_condition_ds.data_vars[var].dims + assert dim_name in initial_condition_ds.coords - # Unlike the data writer outputs, aggregator logs include IC step - assert len(inference_logs) == config.n_forward_steps + 1 - assert 
len(wandb.get_logs()) == len(inference_logs) - for log in inference_logs: - # if these are off by something like 90% then probably the stepper - # is being used instead of the prediction_data - assert log["inference/mean/weighted_rmse/var"] == 0.0 - assert log["inference/mean/weighted_bias/var"] == 0.0 prediction_ds = xr.open_dataset( tmp_path / "autoregressive_predictions.nc", decode_timedelta=False, @@ -248,90 +274,103 @@ def inference_helper( assert len(prediction_ds["time"]) == config.n_forward_steps for i in range(config.n_forward_steps - 1): np.testing.assert_allclose( - prediction_ds["var"].isel(time=i).values + 1, - prediction_ds["var"].isel(time=i + 1).values, + prediction_ds[var].isel(time=i).values + 1, + prediction_ds[var].isel(time=i + 1).values, ) - assert not np.any(np.isnan(prediction_ds["var"].isel(time=i + 1).values)) + assert not np.any(np.isnan(prediction_ds[var].isel(time=i + 1).values)) assert "lat" in prediction_ds.coords assert "lon" in prediction_ds.coords - restart_ds = xr.open_dataset( - tmp_path / "restart.nc", decode_timedelta=False, decode_times=False - ) - np.testing.assert_allclose( - prediction_ds["var"].isel(time=-1).values, - restart_ds["var"].values, - ) + if use_prediction_data: + assert not os.path.exists(tmp_path / "restart.nc") + assert not os.path.exists(tmp_path / "initial_condition.nc") + else: + restart_ds = xr.open_dataset( + tmp_path / "restart.nc", decode_timedelta=False, decode_times=False + ) + np.testing.assert_allclose( + prediction_ds[var].isel(time=-1).values, + restart_ds[var].values, + ) - ic_ds = xr.open_dataset( - tmp_path / "initial_condition.nc", decode_timedelta=False, decode_times=False - ) - np.testing.assert_allclose(ic_ds["var"].values, 0.0) + ic_ds = xr.open_dataset( + tmp_path / "initial_condition.nc", + decode_timedelta=False, + decode_times=False, + ) + np.testing.assert_allclose(ic_ds[var].values, 0.0) metric_ds = xr.open_dataset(tmp_path / "reduced_autoregressive_predictions.nc") - assert "var" in metric_ds.data_vars - assert metric_ds.data_vars["var"].attrs["units"] == "m" - assert metric_ds.data_vars["var"].attrs["long_name"] == "ensemble mean of var" - assert "rmse_var" in metric_ds.data_vars - assert metric_ds.data_vars["rmse_var"].attrs["units"] == "m" + assert var in metric_ds.data_vars + assert metric_ds.data_vars[var].attrs["units"] == "m" + assert metric_ds.data_vars[var].attrs["long_name"] == f"ensemble mean of {var}" + assert f"rmse_{var}" in metric_ds.data_vars + assert metric_ds.data_vars[f"rmse_{var}"].attrs["units"] == "m" assert ( - metric_ds.data_vars["rmse_var"].attrs["long_name"] - == "root mean squared error of var" - ) - assert "bias_var" in metric_ds.data_vars - assert metric_ds.data_vars["bias_var"].attrs["units"] == "m" - assert "min_err_var" in metric_ds.data_vars - assert metric_ds.data_vars["min_err_var"].attrs["units"] == "m" - assert "max_err_var" in metric_ds.data_vars - assert metric_ds.data_vars["max_err_var"].attrs["units"] == "m" - assert "gen_var_var" in metric_ds.data_vars - assert metric_ds.data_vars["gen_var_var"].attrs["units"] == "" + metric_ds.data_vars[f"rmse_{var}"].attrs["long_name"] + == f"root mean squared error of {var}" + ) + assert f"bias_{var}" in metric_ds.data_vars + assert metric_ds.data_vars[f"bias_{var}"].attrs["units"] == "m" + assert f"min_err_{var}" in metric_ds.data_vars + assert metric_ds.data_vars[f"min_err_{var}"].attrs["units"] == "m" + assert f"max_err_{var}" in metric_ds.data_vars + assert metric_ds.data_vars[f"max_err_{var}"].attrs["units"] == "m" + assert 
f"gen_var_{var}" in metric_ds.data_vars + assert metric_ds.data_vars[f"gen_var_{var}"].attrs["units"] == "" assert ( - metric_ds.data_vars["gen_var_var"].attrs["long_name"] - == "prediction variance of var as fraction of target variance" + metric_ds.data_vars[f"gen_var_{var}"].attrs["long_name"] + == f"prediction variance of {var} as fraction of target variance" ) assert "lat" in metric_ds.coords assert "lon" in metric_ds.coords time_mean_diagnostics = xr.open_dataset(tmp_path / "time_mean_diagnostics.nc") actual_var_names = sorted([str(k) for k in time_mean_diagnostics.keys()]) - assert len(actual_var_names) == 2 - assert "bias_map-var" in actual_var_names - assert time_mean_diagnostics.data_vars["bias_map-var"].attrs["units"] == "m" - assert "gen_map-var" in actual_var_names - assert time_mean_diagnostics.data_vars["gen_map-var"].attrs["units"] == "m" + assert len(actual_var_names) == 2 * len(all_out_names) + assert f"bias_map-{var}" in actual_var_names + assert time_mean_diagnostics.data_vars[f"bias_map-{var}"].attrs["units"] == "m" + assert f"gen_map-{var}" in actual_var_names + assert time_mean_diagnostics.data_vars[f"gen_map-{var}"].attrs["units"] == "m" assert len(time_mean_diagnostics.coords) == 2 assert "lat" in time_mean_diagnostics.coords assert "lon" in time_mean_diagnostics.coords zonal_mean_diagnostics = xr.open_dataset(tmp_path / "zonal_mean_diagnostics.nc") actual_var_names = sorted([str(k) for k in zonal_mean_diagnostics.keys()]) - assert len(actual_var_names) == 2 - assert "error-var" in actual_var_names - assert zonal_mean_diagnostics.data_vars["error-var"].attrs["units"] == "m" - assert "gen-var" in actual_var_names - assert zonal_mean_diagnostics.data_vars["gen-var"].attrs["units"] == "" + assert len(actual_var_names) == 2 * len(all_out_names) + assert f"error-{var}" in actual_var_names + assert zonal_mean_diagnostics.data_vars[f"error-{var}"].attrs["units"] == "m" + assert f"gen-{var}" in actual_var_names + assert zonal_mean_diagnostics.data_vars[f"gen-{var}"].attrs["units"] == "" assert len(zonal_mean_diagnostics.coords) == 1 assert "lat" in zonal_mean_diagnostics.coords for source in ["target", "prediction"]: histograms = xr.open_dataset(tmp_path / f"histograms_{source}.nc") actual_var_names = sorted([str(k) for k in histograms.keys()]) - assert len(actual_var_names) == 2 - assert "var" in actual_var_names - assert histograms.data_vars["var"].attrs["units"] == "count" - assert "var_bin_edges" in actual_var_names - assert histograms.data_vars["var_bin_edges"].attrs["units"] == "m" - var_counts_per_timestep = histograms["var"].sum(dim=["bin"]) + # NOTE: target histograms include forcing variables + n_vars = ( + len(all_out_names) + if source == "prediction" + else len(all_out_names) + len(forcing_names) + ) + assert len(actual_var_names) == 2 * n_vars + assert var in actual_var_names + assert histograms.data_vars[var].attrs["units"] == "count" + assert f"{var}_bin_edges" in actual_var_names + assert histograms.data_vars[f"{var}_bin_edges"].attrs["units"] == "m" + var_counts_per_timestep = histograms[var].sum(dim=["bin"]) same_count_each_timestep = np.all( var_counts_per_timestep.values == var_counts_per_timestep.values[0] ) assert same_count_each_timestep if monthly_reference_filename is not None: - assert "inference/annual/var" in inference_logs[-1] - assert "inference/annual/r2_gen_var" in inference_logs[-1] - assert "inference/annual/r2_target_var" in inference_logs[-1] + assert f"inference/annual/{var}" in wandb_logs[-1] + assert f"inference/annual/r2_gen_{var}" 
in wandb_logs[-1] + assert f"inference/annual/r2_target_{var}" in wandb_logs[-1] + assert "inference/total_steps_per_second" in wandb_logs[-1] @pytest.mark.parametrize( @@ -345,14 +384,21 @@ def test_inference_writer_boundaries( out_names = ["var"] all_names = list(set(in_names).union(out_names)) stepper_path = tmp_path / "stepper" + + horizontal = [DimSize("grid_yt", 4), DimSize("grid_xt", 8)] + dim_sizes = DimSizes( n_time=n_forward_steps + 1, - n_lat=4, - n_lon=8, + horizontal=horizontal, nz_interface=4, ) save_plus_one_stepper( - stepper_path, names=all_names, mean=0.0, std=1.0, data_shape=dim_sizes.shape_2d + stepper_path, + in_names, + out_names, + mean=0.0, + std=1.0, + data_shape=dim_sizes.shape_nd, ) data = FV3GFSData( path=tmp_path, @@ -376,12 +422,17 @@ def test_inference_writer_boundaries( with open(config_filename, "w") as f: yaml.dump(dataclasses.asdict(config), f) with mock_wandb() as wandb: - inference_logs = main( + wandb.configure(log_to_wandb=True) + main( yaml_config=str(config_filename), ) - # initial condition + n_forward_steps autoregressive steps - assert len(inference_logs) == config.n_forward_steps + 1 - assert len(wandb.get_logs()) == len(inference_logs) + inference_logs = wandb.get_logs() + n_ic_timesteps = 1 + summary_log_step = 1 + assert ( + len(inference_logs) + == n_ic_timesteps + config.n_forward_steps + summary_log_step + ) prediction_ds = xr.open_dataset( tmp_path / "autoregressive_predictions.nc", decode_timedelta=False @@ -401,7 +452,6 @@ def test_inference_writer_boundaries( tar["lat"].values, num_lon=len(tar["lon"]) ) # check time mean metrics - assert inference_logs[-1]["inference/mean/forecast_step"] == n_forward_steps tol = 1e-4 # relative tolerance assert metrics.root_mean_squared_error( tar_time_mean, gen_time_mean, area_weights @@ -416,13 +466,13 @@ def test_inference_writer_boundaries( prediction_ds = prediction_ds.isel(sample=0) target_ds = target_ds.isel(sample=0) - ds = xr.open_dataset(data._data_filename) + ds = xr.open_dataset(data.data_filename) for i in range(0, n_forward_steps): # metrics logs includes IC while saved data does not - log = inference_logs[i + 1] + log = inference_logs[i + n_ic_timesteps] # metric steps should match lead times - assert log["inference/mean/forecast_step"] == i + 1 + assert log["inference/mean/forecast_step"] == i + n_ic_timesteps gen_i = torch.from_numpy(gen.isel(time=i).values) tar_i = torch.from_numpy(tar.isel(time=i).values) # check that manually computed metrics match logged metrics @@ -469,14 +519,21 @@ def test_inference_data_time_coarsening(tmp_path: pathlib.Path): out_names = ["var"] all_names = list(set(in_names).union(out_names)) stepper_path = tmp_path / "stepper" + + horizontal = [DimSize("grid_yt", 16), DimSize("grid_xt", 32)] + dim_sizes = DimSizes( n_time=9, - n_lat=16, - n_lon=32, + horizontal=horizontal, nz_interface=4, ) save_plus_one_stepper( - stepper_path, names=all_names, mean=0.0, std=1.0, data_shape=dim_sizes.shape_2d + stepper_path, + in_names, + out_names, + mean=0.0, + std=1.0, + data_shape=dim_sizes.shape_nd, ) data = FV3GFSData( path=tmp_path, @@ -542,27 +599,33 @@ def _make_data(): for var in vars } - loss = 42.0 - fake_data = { - k: _make_data() - for k in ("gen_data", "target_data", "gen_data_norm", "target_data_norm") - } - stepped = SteppedData( - loss, - fake_data["gen_data"], - fake_data["target_data"], - fake_data["gen_data_norm"], - fake_data["target_data_norm"], - ) - - sigma_coords = SigmaCoordinates( + vertical_coordinate = HybridSigmaPressureCoordinate( 
ak=torch.linspace(0, 1, nz + 1, device=get_device()), bk=torch.linspace(0, 1, nz + 1, device=get_device()), ) - derived_stepped = compute_stepped_derived_quantities( - stepped, sigma_coords, TIMESTEP + + def derive_func(data: TensorMapping, forcing_data: TensorMapping) -> TensorDict: + updated = compute_derived_quantities( + dict(data), + vertical_coordinate=vertical_coordinate, + timestep=TIMESTEP, + forcing_data=dict(forcing_data), + ) + return updated + + metrics = {"loss": 42.0} + fake_data = {k: _make_data() for k in ("gen_data", "target_data")} + stepped = TrainOutput( + metrics, + fake_data["gen_data"], + fake_data["target_data"], + time=xr.DataArray(np.zeros((n_sample, n_time)), dims=["sample", "time"]), + normalize=lambda x: x, + derive_func=derive_func, ) + derived_stepped = stepped.compute_derived_variables() + dry_air_name = "surface_pressure_due_to_dry_air" water_path_name = "total_water_path" existence_check = ( @@ -595,14 +658,22 @@ def test_derived_metrics_run_without_errors(tmp_path: pathlib.Path): out_names = ["var", "PRESsfc", "specific_total_water_0", "specific_total_water_1"] all_names = list(set(in_names).union(out_names)) stepper_path = tmp_path / "stepper" + + horizontal = [DimSize("grid_yt", 16), DimSize("grid_xt", 32)] + dim_sizes = DimSizes( n_time=n_forward_steps + 1, - n_lat=16, - n_lon=32, - nz_interface=4, + horizontal=horizontal, + nz_interface=3, ) save_plus_one_stepper( - stepper_path, names=all_names, mean=0.0, std=1.0, data_shape=dim_sizes.shape_2d + stepper_path, + in_names, + out_names, + mean=0.0, + std=1.0, + data_shape=dim_sizes.shape_nd, + nz_interface=dim_sizes.nz_interface, ) time_varying_values = [float(i) for i in range(dim_sizes.n_time)] data = FV3GFSData( @@ -611,6 +682,7 @@ def test_derived_metrics_run_without_errors(tmp_path: pathlib.Path): dim_sizes=dim_sizes, time_varying_values=time_varying_values, timestep_days=TIMESTEP.total_seconds() / 86400, + num_data_workers=2, ) config = InferenceEvaluatorConfig( experiment_dir=str(tmp_path), @@ -630,10 +702,14 @@ def test_derived_metrics_run_without_errors(tmp_path: pathlib.Path): with open(config_filename, "w") as f: yaml.dump(dataclasses.asdict(config), f) - with mock_wandb() as _: - _ = main( - yaml_config=str(config_filename), - ) + with mock_wandb() as wandb: + wandb.configure(log_to_wandb=True) + main(yaml_config=str(config_filename)) + inference_logs = wandb.get_logs() + + # derived variables should not have normalized metrics reported + assert "inference/mean_norm/weighted_rmse/total_water_path" not in inference_logs[0] + assert "inference/time_mean_norm/rmse/total_water_path" not in inference_logs[-1] @pytest.mark.parametrize( @@ -678,14 +754,20 @@ def test_inference_ocean_override(tmp_path: pathlib.Path): all_names = list(set(in_names).union(out_names)) stepper_path = tmp_path / "stepper" n_forward_steps = 8 + + horizontal = [DimSize("grid_yt", 4), DimSize("grid_xt", 8)] dim_sizes = DimSizes( n_time=n_forward_steps + 1, - n_lat=4, - n_lon=8, + horizontal=horizontal, nz_interface=4, ) save_plus_one_stepper( - stepper_path, names=all_names, mean=0.0, std=1.0, data_shape=dim_sizes.shape_2d + stepper_path, + in_names, + out_names, + mean=0.0, + std=1.0, + data_shape=dim_sizes.shape_nd, ) data = FV3GFSData( path=tmp_path, @@ -708,10 +790,7 @@ def test_inference_ocean_override(tmp_path: pathlib.Path): forward_steps_in_memory=4, ocean=ocean_override, ) - stepper = config.load_stepper( - sigma_coordinates=SigmaCoordinates(ak=torch.arange(7), bk=torch.arange(7)), - area=torch.ones(10), - ) + 
stepper = config.load_stepper() assert isinstance(stepper.ocean, Ocean) assert ( stepper.ocean.surface_temperature_name @@ -719,6 +798,9 @@ def test_inference_ocean_override(tmp_path: pathlib.Path): ) assert stepper.ocean.ocean_fraction_name == ocean_override.ocean_fraction_name + stepper_config = config.load_stepper_config() + assert stepper_config.ocean == ocean_override + def test_inference_timestep_mismatch_error(tmp_path: pathlib.Path): """Test that inference with a model trained with a different timestep than @@ -726,16 +808,18 @@ def test_inference_timestep_mismatch_error(tmp_path: pathlib.Path): """ in_names = ["var"] out_names = ["var"] - all_names = list(set(in_names).union(out_names)) stepper_path = tmp_path / "stepper_test_data" - dim_sizes = DimSizes(n_time=8, n_lat=4, n_lon=8, nz_interface=2) + + horizontal = [DimSize("grid_yt", 4), DimSize("grid_xt", 8)] + dim_sizes = DimSizes(n_time=8, horizontal=horizontal, nz_interface=2) std = 1.0 save_plus_one_stepper( stepper_path, - names=all_names, + in_names, + out_names, mean=0.0, std=std, - data_shape=dim_sizes.shape_2d, + data_shape=dim_sizes.shape_nd, timestep=TIMESTEP, ) use_prediction_data = False @@ -743,10 +827,70 @@ def test_inference_timestep_mismatch_error(tmp_path: pathlib.Path): with pytest.raises(ValueError, match="Timestep of the loaded stepper"): inference_helper( tmp_path, - all_names, + in_names, + out_names, use_prediction_data, dim_sizes, n_forward_steps, stepper_path, timestep=datetime.timedelta(days=20), ) + + +def test_inference_includes_diagnostics(tmp_path: pathlib.Path): + """Test that diagnostics are included in evaluator metrics and outputs.""" + # NOTE: size of in_names and out_names has to be the same here or the + # PlusOne outputs won't have the right shape + in_names = ["prog", "forcing_var", "DSWRFtoa"] + out_names = ["prog", "ULWRFtoa", "USWRFtoa"] + stepper_path = tmp_path / "stepper" + horizontal = [DimSize("grid_yt", 16), DimSize("grid_xt", 32)] + use_prediction_data = False + n_forward_steps = 2 + dim_sizes = DimSizes( + n_time=n_forward_steps + 1, + horizontal=horizontal, + nz_interface=4, + ) + save_plus_one_stepper( + stepper_path, + in_names, + out_names, + mean=0.0, + std=1.0, + data_shape=dim_sizes.shape_nd, + timestep=datetime.timedelta(days=20), + ) + inference_helper( + tmp_path, + in_names, + out_names, + use_prediction_data, + dim_sizes, + n_forward_steps, + stepper_path, + save_monthly_files=False, # requires timestep == 6h + timestep=datetime.timedelta(days=20), + derived_names=["net_energy_flux_toa_into_atmosphere"], + ) + ds = xr.open_dataset( + tmp_path / "autoregressive_predictions.nc", + decode_timedelta=False, + decode_times=False, + ) + # prognostic in + assert "prog" in ds + # diags in + assert "ULWRFtoa" in ds + assert "USWRFtoa" in ds + # derived in + assert "net_energy_flux_toa_into_atmosphere" in ds + # forcings not in + assert "DSWRFtoa" not in ds + assert "forcing_var" not in ds + # assert only prognostic variables are in initial condition and restart files + for filename in ["initial_condition.nc", "restart.nc"]: + ds = xr.open_dataset(tmp_path / filename) + assert "USWRFtoa" not in ds + assert "forcing_var" not in ds + assert "prog" in ds diff --git a/fme/fme/ace/inference/test_inference.py b/fme/fme/ace/inference/test_inference.py index 66c4bc1..cfabd18 100644 --- a/fme/fme/ace/inference/test_inference.py +++ b/fme/fme/ace/inference/test_inference.py @@ -3,7 +3,7 @@ import dataclasses import datetime import pathlib -from typing import List, Tuple +from typing 
import List import cftime import numpy as np @@ -13,6 +13,13 @@ import yaml import fme +from fme.ace.data_loading.batch_data import PrognosticState +from fme.ace.data_loading.inference import ( + ExplicitIndices, + ForcingDataLoaderConfig, + InferenceInitialConditionIndices, + TimestampList, +) from fme.ace.inference.data_writer import DataWriterConfig from fme.ace.inference.inference import ( InferenceConfig, @@ -21,17 +28,17 @@ main, ) from fme.ace.registry import ModuleSelector -from fme.core.data_loading.data_typing import SigmaCoordinates -from fme.core.data_loading.inference import ( - ExplicitIndices, - ForcingDataLoaderConfig, - InferenceInitialConditionIndices, - TimestampList, +from fme.ace.stepper import SingleModuleStepperConfig +from fme.ace.testing import DimSizes, FV3GFSData +from fme.core.coordinates import ( + DimSize, + HybridSigmaPressureCoordinate, + LatLonCoordinates, ) +from fme.core.gridded_ops import LatLonOperations from fme.core.logging_utils import LoggingConfig -from fme.core.normalizer import FromStateNormalizer -from fme.core.stepper import SingleModuleStepperConfig -from fme.core.testing import DimSizes, FV3GFSData +from fme.core.normalizer import NormalizationConfig +from fme.core.testing import mock_wandb TIMESTEP = datetime.timedelta(hours=6) @@ -47,7 +54,7 @@ def save_stepper( out_names: List[str], mean: float, std: float, - data_shape: Tuple[int, int, int], + data_shape: List[int], timestep: datetime.timedelta = TIMESTEP, ): all_names = list(set(in_names).union(out_names)) @@ -55,19 +62,19 @@ def save_stepper( builder=ModuleSelector(type="prebuilt", config={"module": PlusOne()}), in_names=in_names, out_names=out_names, - normalization=FromStateNormalizer( - state={ - "means": {name: mean for name in all_names}, - "stds": {name: std for name in all_names}, - } + normalization=NormalizationConfig( + means={name: mean for name in all_names}, + stds={name: std for name in all_names}, ), ) area = torch.ones(data_shape[-2:], device=fme.get_device()) - sigma_coordinates = SigmaCoordinates(ak=torch.arange(7), bk=torch.arange(7)) + vertical_coordinate = HybridSigmaPressureCoordinate( + ak=torch.arange(7), bk=torch.arange(7) + ) stepper = config.get_stepper( - img_shape=data_shape[-2:], - area=area, - sigma_coordinates=sigma_coordinates, + img_shape=(data_shape[-2], data_shape[-1]), + gridded_operations=LatLonOperations(area), + vertical_coordinate=vertical_coordinate, timestep=timestep, ) torch.save({"stepper": stepper.get_state()}, path) @@ -75,13 +82,16 @@ def save_stepper( def test_inference_entrypoint(tmp_path: pathlib.Path): forward_steps_in_memory = 2 - in_names = ["prog", "forcing_var"] - out_names = ["prog", "diagnostic_var"] + # NOTE: number of inputs and outputs has to be the same for the PlusOne + # stepper module to work properly + in_names = ["prog", "forcing_var", "DSWRFtoa"] + out_names = ["prog", "ULWRFtoa", "USWRFtoa"] stepper_path = tmp_path / "stepper" + horizontal = [DimSize("grid_yt", 16), DimSize("grid_xt", 32)] + dim_sizes = DimSizes( n_time=9, - n_lat=16, - n_lon=32, + horizontal=horizontal, nz_interface=4, ) save_stepper( @@ -90,23 +100,37 @@ def test_inference_entrypoint(tmp_path: pathlib.Path): out_names=out_names, mean=0.0, std=1.0, - data_shape=dim_sizes.shape_2d, + data_shape=dim_sizes.shape_nd, ) data = FV3GFSData( path=tmp_path, - names=["forcing_var"], + names=["forcing_var", "DSWRFtoa"], dim_sizes=dim_sizes, timestep_days=0.25, + save_vertical_coordinate=False, ) initial_condition = xr.Dataset( - {"prog": 
xr.DataArray(np.random.rand(2, 16, 32), dims=["sample", "lat", "lon"])} + { + "prog": xr.DataArray( + np.random.rand(2, 16, 32), dims=["sample", "lat", "lon"] + ), + "forcing": xr.DataArray( + np.random.rand(2, 16, 32), dims=["sample", "lat", "lon"] + ), + "DSWRFtoa": xr.DataArray( + np.random.rand(2, 16, 32), dims=["sample", "lat", "lon"] + ), + } ) initial_condition_path = tmp_path / "init_data" / "ic.nc" initial_condition_path.parent.mkdir() initial_condition["time"] = xr.DataArray( - [cftime.datetime(2000, 1, 1, 6), cftime.datetime(2000, 1, 1, 18)], - dims=["sample"], + [ + cftime.DatetimeProlepticGregorian(2000, 1, 1, 6), + cftime.DatetimeProlepticGregorian(2000, 1, 1, 18), + ], + dims=["time"], ) initial_condition.to_netcdf(initial_condition_path, mode="w") forcing_loader = ForcingDataLoaderConfig( @@ -122,7 +146,7 @@ def test_inference_entrypoint(tmp_path: pathlib.Path): logging=LoggingConfig( log_to_screen=True, log_to_file=False, - log_to_wandb=False, + log_to_wandb=True, ), initial_condition=InitialConditionConfig(path=str(initial_condition_path)), forcing_loader=forcing_loader, @@ -131,9 +155,39 @@ def test_inference_entrypoint(tmp_path: pathlib.Path): config_filename = tmp_path / "config.yaml" with open(config_filename, "w") as f: yaml.dump(dataclasses.asdict(config), f) - main(yaml_config=str(config_filename)) + + with mock_wandb() as wandb: + wandb.configure(log_to_wandb=True) + main(yaml_config=str(config_filename)) + wandb_logs = wandb.get_logs() + + n_ic_timesteps = 1 + summary_log_step = 1 + assert len(wandb_logs) == n_ic_timesteps + config.n_forward_steps + summary_log_step + for i, log in enumerate(wandb_logs): + for metric, val in log.items(): + # check that time series metrics match + if "inference/mean" in metric: + if i > 0: + assert metric in wandb_logs[i] + if np.isnan(val): + assert np.isnan(wandb_logs[i][metric]) + else: + assert wandb_logs[i][metric] == val + elif not np.isnan(val): # for IC only valid data is reported to wandb + assert metric in wandb_logs[i] + assert wandb_logs[i][metric] == val + ds = xr.open_dataset(tmp_path / "autoregressive_predictions.nc") + # prognostic in assert "prog" in ds + # diags in + assert "ULWRFtoa" in ds + assert "USWRFtoa" in ds + # derived in + assert "net_energy_flux_toa_into_atmosphere" in ds + # forcings not in + assert "DSWRFtoa" not in ds assert "forcing_var" not in ds assert ds["prog"].sizes == {"time": 4, "sample": 2, "lat": 16, "lon": 32} np.testing.assert_allclose( @@ -142,6 +196,26 @@ def test_inference_entrypoint(tmp_path: pathlib.Path): np.testing.assert_allclose( ds["prog"].isel(time=1).values, ds["prog"].isel(time=0).values + 1, rtol=1e-6 ) + saved_data = xr.open_dataset(data.data_filename) + ops = LatLonCoordinates( + lat=torch.as_tensor(saved_data["grid_yt"].values), + lon=torch.as_tensor(saved_data["grid_xt"].values), + ).gridded_operations + # check that inference logs match raw output + for i in range(1, config.n_forward_steps + 1): + for log_name in wandb_logs[i]: + if "inference/mean/weighted_mean_gen" in log_name: + variable_name = log_name.split("/")[-1] + # note raw output does not include initial condition, hence + # i-1 below. Code uses area from data, not stepper above. 
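(Annotation, not part of the diff.) For intuition about the check that follows, `area_weighted_mean` on a regular lat-lon grid presumably reduces to a cosine-latitude weighted average; the sketch below shows that computation in plain torch. The function name is made up here, and the real `gridded_operations` implementation may differ in detail (for example by using explicit cell areas).

```python
import torch


def cos_lat_weighted_mean(field: torch.Tensor, lat_deg: torch.Tensor) -> torch.Tensor:
    # field: [..., n_lat, n_lon]; weights proportional to cos(latitude)
    weights = torch.cos(torch.deg2rad(lat_deg))[:, None]  # [n_lat, 1]
    weights = weights.expand(lat_deg.shape[0], field.shape[-1])  # [n_lat, n_lon]
    return (field * weights).sum(dim=(-2, -1)) / weights.sum()
```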
+ raw_variable = ds[variable_name].isel(time=i - 1) + raw_global_mean = ops.area_weighted_mean( + torch.as_tensor(raw_variable.values) + ).mean() + np.testing.assert_allclose( + raw_global_mean, wandb_logs[i][log_name], rtol=1e-6 + ) + assert "inference/total_steps_per_second" in wandb_logs[-1] def test_get_initial_condition(): @@ -150,13 +224,19 @@ def test_get_initial_condition(): np.random.rand(2, 16, 32), dims=["sample", "lat", "lon"] ) data = xr.Dataset({"prog": prognostic_da, "time": time_da}) - initial_condition, initial_times = get_initial_condition(data, ["prog"]) + initial_condition = get_initial_condition(data, ["prog"]) + assert isinstance(initial_condition, PrognosticState) + batch_data = initial_condition.as_batch_data() + assert batch_data.time.shape == (2, 1) + initial_times = batch_data.time.isel(time=0) assert initial_times.shape == (2,) assert initial_times[0] == 0 assert initial_times[1] == 5 - assert initial_condition["prog"].shape == (2, 16, 32) - np.testing.assert_allclose(initial_condition["prog"].numpy(), data["prog"].values) - assert initial_condition["prog"].device == fme.get_device() + assert batch_data.data["prog"].shape == (2, 1, 16, 32) + np.testing.assert_allclose( + batch_data.data["prog"].squeeze(dim=1).cpu().numpy(), data["prog"].values + ) + assert batch_data.time.isel(time=0).equals(initial_times) def test_get_initial_condition_raises_bad_variable_shape(): diff --git a/fme/fme/ace/inference/test_segmented.py b/fme/fme/ace/inference/test_segmented.py new file mode 100644 index 0000000..bc552d7 --- /dev/null +++ b/fme/fme/ace/inference/test_segmented.py @@ -0,0 +1,244 @@ +"""Tests for segmented inference entrypoint.""" + +import dataclasses +import datetime +import os +import pathlib +import tempfile +import unittest.mock +from typing import List + +import cftime +import numpy as np +import pytest +import torch +import xarray as xr +import yaml + +import fme +from fme.ace.data_loading.inference import ForcingDataLoaderConfig, TimestampList +from fme.ace.inference.data_writer import DataWriterConfig +from fme.ace.inference.inference import ( + InitialConditionConfig, + main, + run_segmented_inference, +) +from fme.ace.registry import ModuleSelector +from fme.ace.stepper import SingleModuleStepperConfig +from fme.ace.testing import DimSizes, FV3GFSData +from fme.core.coordinates import DimSize, HybridSigmaPressureCoordinate +from fme.core.dataset.config import XarrayDataConfig +from fme.core.gridded_ops import LatLonOperations +from fme.core.logging_utils import LoggingConfig +from fme.core.normalizer import NormalizationConfig + +TIMESTEP = datetime.timedelta(hours=6) + + +class PlusOne(torch.nn.Module): + def forward(self, x): + return x + 1 + + +def save_stepper( + path: pathlib.Path, + in_names: List[str], + out_names: List[str], + mean: float, + std: float, + data_shape: List[int], + timestep: datetime.timedelta = TIMESTEP, +): + all_names = list(set(in_names).union(out_names)) + config = SingleModuleStepperConfig( + builder=ModuleSelector(type="prebuilt", config={"module": PlusOne()}), + in_names=in_names, + out_names=out_names, + normalization=NormalizationConfig( + means={name: mean for name in all_names}, + stds={name: std for name in all_names}, + ), + ) + area = torch.ones(data_shape[-2:], device=fme.get_device()) + vertical_coordinate = HybridSigmaPressureCoordinate( + ak=torch.arange(7), bk=torch.arange(7) + ) + stepper = config.get_stepper( + img_shape=(data_shape[-2], data_shape[-1]), + gridded_operations=LatLonOperations(area), + 
vertical_coordinate=vertical_coordinate, + timestep=timestep, + ) + torch.save({"stepper": stepper.get_state()}, path) + + +def test_inference_segmented_entrypoint(): + # we use tempfile here instead of pytest tmp_path fixture, because the latter causes + # issues with checking last modified time of files produced by the test. + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_path = pathlib.Path(tmp_dir) + forward_steps_in_memory = 2 + in_names = ["prog", "forcing_var"] + out_names = ["prog", "diagnostic_var"] + stepper_path = tmp_path / "stepper" + horizontal = [DimSize("grid_yt", 16), DimSize("grid_xt", 32)] + + dim_sizes = DimSizes( + n_time=18, + horizontal=horizontal, + nz_interface=4, + ) + save_stepper( + stepper_path, + in_names=in_names, + out_names=out_names, + mean=0.0, + std=1.0, + data_shape=dim_sizes.shape_nd, + ) + data = FV3GFSData( + path=tmp_path, + names=["forcing_var"], + dim_sizes=dim_sizes, + timestep_days=0.25, + ) + initial_condition = xr.Dataset( + { + "prog": xr.DataArray( + np.random.rand(2, 16, 32), dims=["sample", "lat", "lon"] + ) + } + ) + + initial_condition_path = tmp_path / "init_data" / "ic.nc" + initial_condition_path.parent.mkdir() + initial_condition["time"] = xr.DataArray( + [cftime.datetime(2000, 1, 1, 6), cftime.datetime(2000, 1, 1, 18)], + dims=["sample"], + ) + initial_condition.to_netcdf(initial_condition_path, mode="w") + forcing_loader = ForcingDataLoaderConfig( + dataset=data.inference_data_loader_config.dataset, + num_data_workers=0, + ) + + run_dir = tmp_path / "segmented_run" + config = fme.ace.InferenceConfig( + experiment_dir=str(run_dir), + n_forward_steps=3, + forward_steps_in_memory=forward_steps_in_memory, + checkpoint_path=str(stepper_path), + logging=LoggingConfig( + log_to_screen=True, log_to_file=False, log_to_wandb=False + ), + initial_condition=InitialConditionConfig( + path=str(initial_condition_path), + start_indices=TimestampList(["2000-01-01T06:00:00"]), + ), + forcing_loader=forcing_loader, + data_writer=DataWriterConfig(save_prediction_files=True), + ) + + # run one segment of 3 steps + config_path = str(tmp_path / "config.yaml") + with open(config_path, "w") as f: + yaml.dump(dataclasses.asdict(config), f) + main(config_path, 1) + + # run another segment of 3 steps, and ensure first segment is not being re-run + filename = os.path.join( + run_dir, "segment_0000", "autoregressive_predictions.nc" + ) + before_second_segment_mtime = os.path.getmtime(filename) + main(config_path, 2) + after_second_segment_mtime = os.path.getmtime(filename) + assert before_second_segment_mtime == pytest.approx(after_second_segment_mtime) + + # do a non-segmented run of 6 steps + config.n_forward_steps = 6 + config.experiment_dir = str(tmp_path / "non_segmented_run") + config_path = str(tmp_path / "config.yaml") + with open(config_path, "w") as f: + yaml.dump(dataclasses.asdict(config), f) + main(config_path) + + # assert each segment generated output of correct duration + ds_two_segments_0 = xr.open_dataset( + run_dir / "segment_0000" / "autoregressive_predictions.nc" + ) + ds_two_segments_1 = xr.open_dataset( + run_dir / "segment_0001" / "autoregressive_predictions.nc" + ) + assert len(ds_two_segments_0.time) == len(ds_two_segments_1.time) + + # Ensure the second half of the 6-step run matches the second segment of the + # 3-step run. Before comparing, drop init_time and time coordinates, since + # we don't expect these to match. 
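(Annotation, not part of the diff.) The property this test pins down is that segments which have already written `restart.nc` are skipped on re-run, and each new segment is initialized from the previous segment's restart. That logic, taken from `run_segmented_inference` earlier in this diff, condenses to the sketch below; `run_one_segment` is a hypothetical stand-in for a single `run_inference_from_config` call.

```python
import os


def run_segments_sketch(experiment_dir, segments, run_one_segment):
    initial_condition_path = None  # first segment uses the configured IC
    for segment in range(segments):
        segment_dir = os.path.join(experiment_dir, f"segment_{segment:04d}")
        restart_path = os.path.join(segment_dir, "restart.nc")
        if not os.path.exists(restart_path):
            # already-completed segments are skipped, so extending `segments`
            # on a later invocation only runs the new ones
            run_one_segment(segment_dir, initial_condition_path)
        initial_condition_path = restart_path  # next segment starts from here
```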
+ ds_one_segment = xr.open_dataset( + tmp_path / "non_segmented_run" / "autoregressive_predictions.nc" + ) + ds_two_segments_1 = ds_two_segments_1.drop_vars(["init_time", "time"]) + ds_one_segment = ds_one_segment.drop_vars(["init_time", "time"]) + xr.testing.assert_identical( + ds_two_segments_1, ds_one_segment.isel(time=slice(3, None)) + ) + + +def _run_inference_from_config_mock(config: fme.ace.InferenceConfig): + if not os.path.exists(config.experiment_dir): + os.makedirs(config.experiment_dir) + with open(os.path.join(config.experiment_dir, "restart.nc"), "w") as f: + f.write("mock restart file") + with open(os.path.join(config.experiment_dir, "wandb_name_env_var"), "w") as f: + f.write(os.environ.get("WANDB_NAME", "")) + + +def _get_mock_config(experiment_dir: str) -> fme.ace.InferenceConfig: + return fme.ace.InferenceConfig( + experiment_dir=experiment_dir, + n_forward_steps=3, + checkpoint_path="mock_checkpoint", + logging=LoggingConfig( + log_to_screen=True, log_to_file=False, log_to_wandb=False + ), + initial_condition=InitialConditionConfig(path="mock_ic"), + forcing_loader=ForcingDataLoaderConfig( + dataset=XarrayDataConfig(data_path="mock_forcing") + ), + ) + + +def test_run_segmented_inference(tmp_path, monkeypatch): + WRITTEN_WANDB_NAME_FILENAME = "wandb_name_env_var" + mock = unittest.mock.MagicMock(side_effect=_run_inference_from_config_mock) + config = _get_mock_config(str(tmp_path)) + + with unittest.mock.patch( + "fme.ace.inference.inference.run_inference_from_config", new=mock + ): + # run a single segment + monkeypatch.setenv("WANDB_NAME", "run_name") + run_segmented_inference(config, 1) + segment_dir = os.path.join(config.experiment_dir, "segment_0000") + expected_restart_path = os.path.join(segment_dir, "restart.nc") + assert os.path.exists(expected_restart_path) + assert mock.call_count == 1 + with open(os.path.join(segment_dir, WRITTEN_WANDB_NAME_FILENAME)) as f: + assert f.read() == "run_name-segment_0000" + + # rerun the same segment and ensure run_inference_from_config isn't called again + run_segmented_inference(config, 1) + assert os.path.exists(expected_restart_path) + assert mock.call_count == 1 + + # extend to three segments and ensure exactly three run_inference_from_config + # calls have been made + monkeypatch.setenv("WANDB_NAME", "run_name") + run_segmented_inference(config, 3) + for i in range(3): + segment_dir = os.path.join(config.experiment_dir, f"segment_{i:04d}") + expected_restart_path = os.path.join(segment_dir, "restart.nc") + assert os.path.exists(expected_restart_path) + with open(os.path.join(segment_dir, WRITTEN_WANDB_NAME_FILENAME)) as f: + assert f.read() == f"run_name-segment_{i:04d}" + assert mock.call_count == 3 diff --git a/fme/fme/ace/models/healpix/__init__.py b/fme/fme/ace/models/healpix/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/fme/fme/ace/models/healpix/healpix_activations.py b/fme/fme/ace/models/healpix/healpix_activations.py new file mode 100644 index 0000000..336cc7a --- /dev/null +++ b/fme/fme/ace/models/healpix/healpix_activations.py @@ -0,0 +1,206 @@ +# flake8: noqa +# Copied from https://github.com/NVIDIA/modulus/commit/89a6091bd21edce7be4e0539cbd91507004faf08 +# Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import dataclasses +from typing import Literal + +import torch as th +import torch.nn as nn + +from .healpix_layers import HEALPixLayer + +# DOWNSAMPLING BLOCKS + + +class MaxPool(nn.Module): + """Wrapper for applying Max Pooling with HEALPix or other tensor data. + + This class wraps the `nn.MaxPool2d` class to handle tensor data with + HEALPix or other geometry layers. + """ + + def __init__( + self, + pooling: int = 2, + enable_nhwc: bool = False, + enable_healpixpad: bool = False, + ): + """ + Args: + pooling (int, optional): Pooling kernel size passed to geometry layer. + enable_nhwc (bool, optional): Enable nhwc format, passed to wrapper. + enable_healpixpad (bool, optional): If HEALPixPadding should be enabled, passed to wrapper. + """ + super().__init__() + self.maxpool = HEALPixLayer( + layer=nn.MaxPool2d, + kernel_size=pooling, + enable_nhwc=enable_nhwc, + enable_healpixpad=enable_healpixpad, + ) + + def forward(self, x: th.Tensor) -> th.Tensor: + """Forward pass of the MaxPool. + + Args: + x: The values to MaxPool. + + Returns: + The MaxPooled values. + """ + return self.maxpool(x) + + +class AvgPool(nn.Module): + """Wrapper for applying Average Pooling with HEALPix or other tensor data. + + This class wraps the `nn.AvgPool2d` class to handle tensor data with + HEALPix or other geometry layers. + """ + + def __init__( + self, + pooling: int = 2, + enable_nhwc: bool = False, + enable_healpixpad: bool = False, + ): + """ + Args: + pooling (int, optional): Pooling kernel size passed to geometry layer. + enable_nhwc (bool, optional): Enable nhwc format, passed to wrapper. + enable_healpixpad (bool, optional): If HEALPixPadding should be enabled, passed to wrapper. + """ + super().__init__() + self.avgpool = HEALPixLayer( + layer=nn.AvgPool2d, + kernel_size=pooling, + enable_nhwc=enable_nhwc, + enable_healpixpad=enable_healpixpad, + ) + + def forward(self, x: th.Tensor) -> th.Tensor: + """Forward pass of the AvgPool layer. + + Args: + x: The values to average. + + Returns: + The averaged values. + """ + return self.avgpool(x) + + +@dataclasses.dataclass +class DownsamplingBlockConfig: + """ + Configuration for the downsampling block. + Generally, either a pooling block or a striding conv block. + + Parameters: + block_type: Type of recurrent block, either "MaxPool" or "AvgPool" + pooling: Pooling size + enable_nhwc: Flag to enable NHWC data format, default is False. + enable_healpixpad: Flag to enable HEALPix padding, default is False. + + """ + + block_type: Literal["MaxPool", "AvgPool"] + pooling: int = 2 + enable_nhwc: bool = False + enable_healpixpad: bool = False + + def build(self) -> nn.Module: + """ + Builds the recurrent block model. + + Returns: + Recurrent block. 
+ """ + if self.block_type == "MaxPool": + return MaxPool( + pooling=self.pooling, + enable_nhwc=self.enable_nhwc, + enable_healpixpad=self.enable_healpixpad, + ) + + elif self.block_type == "AvgPool": + return AvgPool( + pooling=self.pooling, + enable_nhwc=self.enable_nhwc, + enable_healpixpad=self.enable_healpixpad, + ) + else: + raise ValueError(f"Unsupported block type: {self.block_type}") + + +@dataclasses.dataclass +class CappedGELUConfig: + """ + Configuration for the CappedGELU activation function. + + Parameters: + cap_value: Cap value for the GELU function, default is 10. + enable_nhwc: Flag to enable NHWC data format, default is False. + enable_healpixpad: Flag to enable HEALPix padding, default is False. + """ + + cap_value: int = 10 + enable_nhwc: bool = False + enable_healpixpad: bool = False + + def build(self) -> nn.Module: + """ + Builds the CappedGELU activation function. + + Returns: + CappedGELU activation function. + """ + return CappedGELU(cap_value=self.cap_value) + + +class CappedGELU(nn.Module): + """ + Implements a GELU with capped maximum value. + + Example + ------- + >>> capped_gelu_func = modulus.models.layers.CappedGELU() + >>> input = th.Tensor([[-2,-1],[0,1],[2,3]]) + >>> capped_gelu_func(input) + tensor([[-0.0455, -0.1587], + [ 0.0000, 0.8413], + [ 1.0000, 1.0000]]) + + """ + + def __init__(self, cap_value=1.0, **kwargs): + """ + Args: + cap_value: Maximum that values will be capped at + **kwargs: Keyword arguments to be passed to the `th.nn.GELU` function + """ + + super().__init__() + self.add_module("gelu", th.nn.GELU(**kwargs)) + self.register_buffer("cap", th.tensor(cap_value, dtype=th.float32)) + + def forward(self, inputs): + x = self.gelu(inputs) + # Convert cap to a scalar value for clamping (ignores grad) + cap_value = self.cap.item() + x = th.clamp(x, max=cap_value) + return x diff --git a/fme/fme/ace/models/healpix/healpix_blocks.py b/fme/fme/ace/models/healpix/healpix_blocks.py new file mode 100644 index 0000000..e1c667b --- /dev/null +++ b/fme/fme/ace/models/healpix/healpix_blocks.py @@ -0,0 +1,926 @@ +# flake8: noqa +# Copied from https://github.com/NVIDIA/modulus/commit/89a6091bd21edce7be4e0539cbd91507004faf08 +# Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import abc +import dataclasses +from typing import Literal, Optional, Tuple, Union, cast + +import torch as th +import torch.nn as nn + +from .healpix_activations import CappedGELUConfig +from .healpix_layers import HEALPixLayer + +# RECURRENT BLOCKS + + +@dataclasses.dataclass +class RecurrentBlockConfig: + """ + Configuration for the recurrent block. + + Parameters: + in_channels: Number of input channels, default is 3. + kernel_size: Size of the kernel, default is 1. + enable_nhwc: Flag to enable NHWC data format, default is False. + enable_healpixpad: Flag to enable HEALPix padding, default is False. 
+ block_type: Type of recurrent block, either "ConvGRUBlock" or "ConvLSTMBlock", + default is "ConvGRUBlock". + """ + + in_channels: int = 3 + kernel_size: int = 1 + enable_nhwc: bool = False + enable_healpixpad: bool = False + block_type: Literal["ConvGRUBlock", "ConvLSTMBlock"] = "ConvGRUBlock" + + def build(self) -> nn.Module: + """ + Builds the recurrent block model. + + Returns: + Recurrent block. + """ + if self.block_type == "ConvGRUBlock": + return ConvGRUBlock( + in_channels=self.in_channels, + kernel_size=self.kernel_size, + enable_nhwc=self.enable_nhwc, + enable_healpixpad=self.enable_healpixpad, + ) + elif self.block_type == "ConvLSTMBlock": + return ConvLSTMBlock( + in_channels=self.in_channels, + kernel_size=self.kernel_size, + enable_nhwc=self.enable_nhwc, + enable_healpixpad=self.enable_healpixpad, + ) + else: + raise ValueError(f"Unsupported block type: {self.block_type}") + + +@dataclasses.dataclass +class ConvBlockConfig: + """ + Configuration for the convolutional block. + + Parameters: + in_channels: Number of input channels, default is 3. + out_channels: Number of output channels, default is 1. + kernel_size: Size of the kernel, default is 3. + dilation: Dilation rate, default is 1. + n_layers: Number of layers, default is 1. + upsampling: Upsampling factor for TransposedConvUpsample, default is 2. + upscale_factor: Upscale factor for ConvNeXtBlock and SymmetricConvNeXtBlock, + default is 4. + latent_channels: Number of latent channels, default is None. + activation: Activation configuration, default is None. + enable_nhwc: Flag to enable NHWC data format, default is False. + enable_healpixpad: Flag to enable HEALPix padding, default is False. + block_type: Type of block, default is "BasicConvBlock". + """ + + in_channels: int = 3 + out_channels: int = 1 + kernel_size: int = 3 + dilation: int = 1 + n_layers: int = 1 + upsampling: int = 2 + upscale_factor: int = 4 + latent_channels: Optional[int] = None + activation: Optional[CappedGELUConfig] = None + enable_nhwc: bool = False + enable_healpixpad: bool = False + block_type: Literal[ + "BasicConvBlock", + "ConvNeXtBlock", + "SymmetricConvNeXtBlock", + "TransposedConvUpsample", + ] = "BasicConvBlock" + + def build(self) -> nn.Module: + """ + Builds the convolutional block model. + + Returns: + Convolutional block model. 
+ """ + if self.block_type == "BasicConvBlock": + return BasicConvBlock( + in_channels=self.in_channels, + out_channels=self.out_channels, + kernel_size=self.kernel_size, + dilation=self.dilation, + n_layers=self.n_layers, + latent_channels=self.latent_channels, + activation=self.activation, + enable_nhwc=self.enable_nhwc, + enable_healpixpad=self.enable_healpixpad, + ) + elif self.block_type == "ConvNeXtBlock": + if self.latent_channels is None: + self.latent_channels = 1 + return ConvNeXtBlock( + in_channels=self.in_channels, + latent_channels=cast(int, self.latent_channels), + out_channels=self.out_channels, + kernel_size=self.kernel_size, + dilation=self.dilation, + upscale_factor=self.upscale_factor, + activation=self.activation, + enable_nhwc=self.enable_nhwc, + enable_healpixpad=self.enable_healpixpad, + ) + elif self.block_type == "SymmetricConvNeXtBlock": + if self.latent_channels is None: + self.latent_channels = 1 + return SymmetricConvNeXtBlock( + in_channels=self.in_channels, + latent_channels=cast(int, self.latent_channels), + out_channels=self.out_channels, + kernel_size=self.kernel_size, + dilation=self.dilation, + upscale_factor=self.upscale_factor, + activation=self.activation, + enable_nhwc=self.enable_nhwc, + enable_healpixpad=self.enable_healpixpad, + ) + elif self.block_type == "TransposedConvUpsample": + return TransposedConvUpsample( + in_channels=self.in_channels, + out_channels=self.out_channels, + upsampling=self.upsampling, + activation=self.activation, + enable_nhwc=self.enable_nhwc, + enable_healpixpad=self.enable_healpixpad, + ) + else: + raise ValueError(f"Unsupported block type: {self.block_type}") + + +class ConvGRUBlock(nn.Module): + """Class that implements a Convolutional GRU. + + Code modified from: + https://github.com/happyjin/ConvGRU-pytorch/blob/master/convGRU.py + """ + + def __init__( + self, + in_channels=3, + kernel_size=1, + enable_nhwc=False, + enable_healpixpad=False, + ): + """ + Args: + in_channels: The number of input channels. + kernel_size: Size of the convolutional kernel. + enable_nhwc: Enable nhwc format, passed to wrapper. + enable_healpixpad: If HEALPixPadding should be enabled, passed to wrapper. + """ + super().__init__() + + self.channels = in_channels + self.conv_gates = HEALPixLayer( + layer=th.nn.Conv2d, + in_channels=in_channels + self.channels, + out_channels=2 * self.channels, # for update_gate, reset_gate respectively + kernel_size=kernel_size, + padding="same", + enable_nhwc=enable_nhwc, + enable_healpixpad=enable_healpixpad, + ) + self.conv_can = HEALPixLayer( + layer=th.nn.Conv2d, + in_channels=in_channels + self.channels, + out_channels=self.channels, # for candidate neural memory + kernel_size=kernel_size, + padding="same", + enable_nhwc=enable_nhwc, + enable_healpixpad=enable_healpixpad, + ) + self.h = th.zeros(1, 1, 1, 1) + + def forward(self, inputs): + """Forward pass of the ConvGRUBlock. + + Args: + inputs: Input to the forward pass. + + Returns: + th.Tensor: Result of the forward pass. 
+ """ + if inputs.shape != self.h.shape: + self.h = th.zeros_like(inputs) + combined = th.cat([inputs, self.h], dim=1) + combined_conv = self.conv_gates(combined) + + gamma, beta = th.split(combined_conv, self.channels, dim=1) + reset_gate = th.sigmoid(gamma) + update_gate = th.sigmoid(beta) + + combined = th.cat([inputs, reset_gate * self.h], dim=1) + cc_cnm = self.conv_can(combined) + cnm = th.tanh(cc_cnm) + + h_next = (1 - update_gate) * self.h + update_gate * cnm + self.h = h_next + + return inputs + h_next + + def reset(self): + """Reset the update gates.""" + self.h = th.zeros_like(self.h) + + +class ConvLSTMBlock(nn.Module): + """Convolutional LSTM block.""" + + def __init__( + self, + in_channels: int = 3, + out_channels: int = 1, + latent_channels: int = 1, + kernel_size: int = 3, + downscale_factor: int = 4, + upscale_factor: int = 4, + n_layers: int = 1, + latent_conv_size: int = 3, # Add latent_conv_size parameter + dilation: int = 1, + activation: nn.Module = None, + enable_nhwc: bool = False, + enable_healpixpad: bool = False, + ): + """ + Args: + in_channels: The number of input channels. + out_channels: The number of output channels. + latent_channels: Number of latent channels. + kernel_size: Size of the convolutional kernel. + downscale_factor: Downscale factor. + upscale_factor: Upscale factor. + n_layers: Number of layers. + latent_conv_size: Size of latent convolution. + dilation: Spacing between kernel points. + activation: Activation function. + enable_nhwc: Enable nhwc format. + enable_healpixpad: If HEALPixPadding should be enabled. + """ + super().__init__() + # Instantiate 1x1 conv to increase/decrease channel depth if necessary + # Skip connection for output + if in_channels == out_channels: + self.skip_module = lambda x: x # Identity-function required in forward pass + else: + self.skip_module = HEALPixLayer( + layer=th.nn.Conv2d, + in_channels=in_channels, + out_channels=in_channels, # out channels describes the space of the output of conv here; but we have the output of LSTM which is the input layer size + kernel_size=1, + enable_nhwc=enable_nhwc, + enable_healpixpad=enable_healpixpad, + ) + # Convolution block + convblock = [] + # 3x3 convolution increasing channels + convblock.append( + HEALPixLayer( + layer=th.nn.Conv2d, + in_channels=in_channels + * 2, # accounts for the h layer, which is concatenated before convolution runs + out_channels=int(latent_channels * upscale_factor), + kernel_size=kernel_size, + dilation=dilation, + enable_nhwc=enable_nhwc, + enable_healpixpad=enable_healpixpad, + ) + ) + if activation is not None: + convblock.append(activation) + # 3x3 convolution maintaining increased channels + convblock.append( + HEALPixLayer( + layer=th.nn.Conv2d, + in_channels=int(latent_channels * upscale_factor), + out_channels=int(latent_channels * upscale_factor), + kernel_size=kernel_size, + dilation=dilation, + enable_nhwc=enable_nhwc, + enable_healpixpad=enable_healpixpad, + ) + ) + if activation is not None: + convblock.append(activation) + + # Now for the LSTM bit + self.channels = in_channels + self.lstm_gates = HEALPixLayer( + layer=th.nn.Conv2d, + in_channels=latent_channels * upscale_factor, + out_channels=self.channels + * 4, # for input_gate, forget_gate, cell_gate, output_gate respectively (LSTM) + kernel_size=kernel_size, + padding="same", + enable_nhwc=enable_nhwc, + enable_healpixpad=enable_healpixpad, + ) + self.h = th.zeros(1, 1, 1, 1) + self.c = th.zeros(1, 1, 1, 1) + self.convblock = nn.Sequential(*convblock) + + def forward(self, 
inputs): + """Forward pass of the ConvLSTMBlock. + + Args: + x: Inputs to the forward pass. + + Returns: + th.Tensor: Result of the forward pass. + """ + if inputs.shape != self.h.shape: + self.h = th.zeros_like(inputs) + self.c = th.zeros_like(inputs) + + combined = th.cat([inputs, self.h], dim=1) + conv_outputs = self.convblock(combined) + + lstm_gates = self.lstm_gates(conv_outputs) + + # Split the combined_conv into input_gate, forget_gate, cell_gate, output_gate + i, f, c_hat, o = th.split(lstm_gates, self.channels, dim=1) + input_gate = th.sigmoid(i) + forget_gate = th.sigmoid(f) + cell_gate = th.tanh(c_hat) + output_gate = th.sigmoid(o) + + self.c = forget_gate * self.c + input_gate * cell_gate + self.h = output_gate * th.tanh(self.c) + + skip_connection = self.skip_module(inputs) + return skip_connection + self.h + + def reset(self): + self.h = th.zeros_like(self.h) + self.c = th.zeros_like(self.c) + + +# CONV BLOCKS + + +class BasicConvBlock(nn.Module): + """Convolution block consisting of n subsequent convolutions and activations.""" + + def __init__( + self, + in_channels=3, + out_channels=1, + kernel_size=3, + dilation=1, + n_layers=1, + latent_channels=None, + activation=None, + enable_nhwc=False, + enable_healpixpad=False, + ): + """ + Args: + in_channels: The number of input channels. + out_channels: The number of output channels. + kernel_size: Size of the convolutional kernel. + dilation: Spacing between kernel points, passed to nn.Conv2d. + n_layers: Number of convolutional layers. + latent_channels: Number of latent channels. + activation: ModuleConfig for activation function to use. + enable_nhwc: Enable nhwc format, passed to wrapper. + enable_healpixpad:: If HEALPixPadding should be enabled, passed to wrapper. + """ + super().__init__() + if latent_channels is None: + latent_channels = max(in_channels, out_channels) + convblock = [] + for n in range(n_layers): + convblock.append( + HEALPixLayer( + layer=th.nn.Conv2d, + in_channels=in_channels if n == 0 else latent_channels, + out_channels=out_channels if n == n_layers - 1 else latent_channels, + kernel_size=kernel_size, + dilation=dilation, + enable_nhwc=enable_nhwc, + enable_healpixpad=enable_healpixpad, + ) + ) + if activation is not None: + convblock.append(activation.build()) + self.convblock = nn.Sequential(*convblock) + + def forward(self, x): + """Forward pass of the BasicConvBlock. + + Args: + x: Inputs to the forward pass. + + Returns: + th.Tensor: Result of the forward pass. + """ + return self.convblock(x) + + +class ConvNeXtBlock(nn.Module): + """A modified ConvNeXt network block as described in the paper + "A ConvNet for the 21st Century" (https://arxiv.org/pdf/2201.03545.pdf). + + This block consists of a series of convolutional layers with optional activation functions, + and a residual connection. + + Parameters: + skip_module: A module to align the input and output channels for the residual connection. + convblock: A sequential container of convolutional layers with optional activation functions. + """ + + def __init__( + self, + in_channels: int = 3, + latent_channels: int = 1, + out_channels: int = 1, + kernel_size: int = 3, + dilation: int = 1, + upscale_factor: int = 4, + activation: Optional[CappedGELUConfig] = None, + enable_nhwc: bool = False, + enable_healpixpad: bool = False, + ): + """ + Initializes a ConvNeXtBlock instance with specified parameters. + + Args: + in_channels: Number of input channels. + latent_channels: Number of latent channels used in the block. 
+ out_channels: Number of output channels. + kernel_size: Size of the convolutional kernels. + dilation: Dilation rate for convolutions. + upscale_factor: Factor by which to upscale the number of latent channels. + activation: Configuration for the activation function used between layers. + enable_nhwc: Whether to enable NHWC format. + enable_healpixpad: Whether to enable HEALPixPadding. + """ + super().__init__() + + # Instantiate 1x1 conv to increase/decrease channel depth if necessary + if in_channels == out_channels: + self.skip_module = lambda x: x # Identity-function required in forward pass + else: + self.skip_module = HEALPixLayer( + layer=th.nn.Conv2d, + in_channels=in_channels, + out_channels=out_channels, + kernel_size=1, + enable_nhwc=enable_nhwc, + enable_healpixpad=enable_healpixpad, + ) + # Convolution block + convblock = [] + # 3x3 convolution increasing channels + convblock.append( + HEALPixLayer( + layer=th.nn.Conv2d, + in_channels=in_channels, + out_channels=int(latent_channels * upscale_factor), + kernel_size=kernel_size, + dilation=dilation, + enable_nhwc=enable_nhwc, + enable_healpixpad=enable_healpixpad, + ) + ) + if activation is not None: + convblock.append(activation.build()) + # 3x3 convolution maintaining increased channels + convblock.append( + HEALPixLayer( + layer=th.nn.Conv2d, + in_channels=int(latent_channels * upscale_factor), + out_channels=int(latent_channels * upscale_factor), + kernel_size=kernel_size, + dilation=dilation, + enable_nhwc=enable_nhwc, + enable_healpixpad=enable_healpixpad, + ) + ) + if activation is not None: + convblock.append(activation.build()) + # Linear postprocessing + convblock.append( + HEALPixLayer( + layer=th.nn.Conv2d, + in_channels=int(latent_channels * upscale_factor), + out_channels=out_channels, + kernel_size=1, + enable_nhwc=enable_nhwc, + enable_healpixpad=enable_healpixpad, + ) + ) + self.convblock = nn.Sequential(*convblock) + + def forward(self, x): + """Forward pass of the ConvNeXtBlock. + + Args: + x: Input tensor. + + Returns: + The result of the forward pass. + """ + return self.skip_module(x) + self.convblock(x) + + +class DoubleConvNeXtBlock(nn.Module): + """A variant of the ConvNeXt block that includes two sequential ConvNeXt blocks within a single module. + + Parameters: + skip_module1: A module to align the input and intermediate channels for the first residual connection. + skip_module2: A module to align the intermediate and output channels for the second residual connection. + convblock1: A sequential container of convolutional layers for the first ConvNeXt block. + convblock2: A sequential container of convolutional layers for the second ConvNeXt block. + """ + + def __init__( + self, + in_channels: int = 3, + out_channels: int = 1, + kernel_size: int = 3, + dilation: int = 1, + upscale_factor: int = 4, + latent_channels: int = 1, + activation: Optional[CappedGELUConfig] = None, + enable_nhwc: bool = False, + enable_healpixpad: bool = False, + ): + """ + Initializes a DoubleConvNeXtBlock instance with specified parameters. + + Args: + in_channels: Number of input channels (default is 3). + out_channels: Number of output channels (default is 1). + kernel_size: Size of the convolutional kernels (default is 3). + dilation: Dilation rate for convolutions (default is 1). + upscale_factor: Factor by which to upscale the number of latent channels (default is 4). + latent_channels: Number of latent channels used in the block (default is 1). 
+ activation: Configuration for the activation function used between layers (default is None). + enable_nhwc: Whether to enable NHWC format (default is False). + enable_healpixpad: Whether to enable HEALPixPadding (default is False). + """ + super().__init__() + + if in_channels == int(latent_channels): + self.skip_module1 = ( + lambda x: x + ) # Identity-function required in forward pass + else: + self.skip_module1 = HEALPixLayer( + layer=th.nn.Conv2d, + in_channels=in_channels, + out_channels=int(latent_channels), + kernel_size=1, + enable_nhwc=enable_nhwc, + enable_healpixpad=enable_healpixpad, + ) + if out_channels == int(latent_channels): + self.skip_module2 = ( + lambda x: x + ) # Identity-function required in forward pass + else: + self.skip_module2 = HEALPixLayer( + layer=th.nn.Conv2d, + in_channels=int(latent_channels), + out_channels=out_channels, + kernel_size=1, + enable_nhwc=enable_nhwc, + enable_healpixpad=enable_healpixpad, + ) + + # 1st ConvNeXt block, the output of this one remains internal + convblock1 = [] + # 3x3 convolution establishing latent channels channels + convblock1.append( + HEALPixLayer( + layer=th.nn.Conv2d, + in_channels=in_channels, + out_channels=int(latent_channels), + kernel_size=kernel_size, + dilation=dilation, + enable_nhwc=enable_nhwc, + enable_healpixpad=enable_healpixpad, + ) + ) + if activation is not None: + convblock1.append(activation.build()) + # 1x1 convolution establishing increased channels + convblock1.append( + HEALPixLayer( + layer=th.nn.Conv2d, + in_channels=int(latent_channels), + out_channels=int(latent_channels * upscale_factor), + kernel_size=1, + dilation=dilation, + enable_nhwc=enable_nhwc, + enable_healpixpad=enable_healpixpad, + ) + ) + if activation is not None: + convblock1.append(activation.build()) + # 1x1 convolution returning to latent channels + convblock1.append( + HEALPixLayer( + layer=th.nn.Conv2d, + in_channels=int(latent_channels * upscale_factor), + out_channels=int(latent_channels), + kernel_size=1, + dilation=dilation, + enable_nhwc=enable_nhwc, + enable_healpixpad=enable_healpixpad, + ) + ) + if activation is not None: + convblock1.append(activation.build()) + self.convblock1 = nn.Sequential(*convblock1) + + # 2nd ConNeXt block, takes the output of the first convnext block + convblock2 = [] + # 3x3 convolution establishing latent channels channels + convblock2.append( + HEALPixLayer( + layer=th.nn.Conv2d, + in_channels=int(latent_channels), + out_channels=int(latent_channels), + kernel_size=kernel_size, + dilation=dilation, + enable_nhwc=enable_nhwc, + enable_healpixpad=enable_healpixpad, + ) + ) + if activation is not None: + convblock2.append(activation.build()) + # 1x1 convolution establishing increased channels + convblock2.append( + HEALPixLayer( + layer=th.nn.Conv2d, + in_channels=int(latent_channels), + out_channels=int(latent_channels * upscale_factor), + kernel_size=1, + dilation=dilation, + enable_nhwc=enable_nhwc, + enable_healpixpad=enable_healpixpad, + ) + ) + if activation is not None: + convblock2.append(activation.build()) + # 1x1 convolution reducing to output channels + convblock2.append( + HEALPixLayer( + layer=th.nn.Conv2d, + in_channels=int(latent_channels * upscale_factor), + out_channels=out_channels, + kernel_size=1, + dilation=dilation, + enable_nhwc=enable_nhwc, + enable_healpixpad=enable_healpixpad, + ) + ) + if activation is not None: + convblock2.append(activation.build()) + self.convblock2 = nn.Sequential(*convblock2) + + def forward(self, x): + """Forward pass of the 
DoubleConvNextBlock + Args: + x: inputs to the forward pass + Returns: + result of the forward pass + """ + # internal convnext result + x1 = self.skip_module1(x) + self.convblock1(x) + # return second convnext result + return self.skip_module2(x1) + self.convblock2(x1) + + +class SymmetricConvNeXtBlock(nn.Module): + """A symmetric variant of the ConvNeXt block, with convolutional layers mirrored + around a central axis for symmetric feature extraction. + + Parameters: + skip_module1: A module to align the input and intermediate channels for the first residual connection. + skip_module2: A module to align the intermediate and output channels for the second residual connection. + convblock1: A sequential container of convolutional layers for the symmetric ConvNeXt block. + """ + + def __init__( + self, + in_channels: int = 3, + latent_channels: int = 1, + out_channels: int = 1, + kernel_size: int = 3, + dilation: int = 1, + upscale_factor: int = 4, + activation: Optional[CappedGELUConfig] = None, + enable_nhwc: bool = False, + enable_healpixpad: bool = False, + ): + """ + Initializes a SymmetricConvNeXtBlock instance with specified parameters. + + Args: + in_channels: Number of input channels (default is 3). + out_channels: Number of output channels (default is 1). + kernel_size: Size of the convolutional kernels (default is 3). + dilation: Dilation rate for convolutions (default is 1). + upscale_factor: Upscale factor. + latent_channels: Number of latent channels used in the block (default is 1). + activation: Configuration for the activation function used between layers (default is None). + enable_nhwc: Whether to enable NHWC format (default is False). + enable_healpixpad: Whether to enable HEALPixPadding (default is False). + """ + if in_channels == int(latent_channels): + self.skip_module = lambda x: x # Identity-function required in forward pass + else: + self.skip_module = HEALPixLayer( + layer=th.nn.Conv2d, + in_channels=in_channels, + out_channels=out_channels, + kernel_size=1, + enable_nhwc=enable_nhwc, + enable_healpixpad=enable_healpixpad, + ) + + # 1st ConvNeXt block, the output of this one remains internal + convblock = [] + # 3x3 convolution establishing latent channels channels + convblock.append( + HEALPixLayer( + layer=th.nn.Conv2d, + in_channels=in_channels, + out_channels=int(latent_channels), + kernel_size=kernel_size, + dilation=dilation, + enable_nhwc=enable_nhwc, + enable_healpixpad=enable_healpixpad, + ) + ) + if activation is not None: + convblock.append(activation.build()) + # 1x1 convolution establishing increased channels + convblock.append( + HEALPixLayer( + layer=th.nn.Conv2d, + in_channels=int(latent_channels), + out_channels=int(latent_channels * upscale_factor), + kernel_size=1, + dilation=dilation, + enable_nhwc=enable_nhwc, + enable_healpixpad=enable_healpixpad, + ) + ) + if activation is not None: + convblock.append(activation.build()) + # 1x1 convolution returning to latent channels + convblock.append( + HEALPixLayer( + layer=th.nn.Conv2d, + in_channels=int(latent_channels * upscale_factor), + out_channels=int(latent_channels), + kernel_size=1, + dilation=dilation, + enable_nhwc=enable_nhwc, + enable_healpixpad=enable_healpixpad, + ) + ) + if activation is not None: + convblock.append(activation.build()) + # 3x3 convolution from latent channels to latent channels + convblock.append( + HEALPixLayer( + layer=th.nn.Conv2d, + in_channels=int(latent_channels), + out_channels=out_channels, # int(latent_channels), + kernel_size=kernel_size, + dilation=dilation, 
+ enable_nhwc=enable_nhwc, + enable_healpixpad=enable_healpixpad, + ) + ) + if activation is not None: + convblock.append(activation.build()) + self.convblock = nn.Sequential(*convblock) + + def forward(self, x): + """Forward pass of the SymmetricConvNextBlock + Args: + x: inputs to the forward pass + Returns: + result of the forward pass + """ + # residual connection with reshaped inpute and output of conv block + return self.skip_module(x) + self.convblock(x) + + +class TransposedConvUpsample(nn.Module): + """Wrapper for upsampling with a transposed convolution using HEALPix or other tensor data. + + This class wraps the `nn.ConvTranspose2d` class to handle tensor data with + HEALPix or other geometry layers. + """ + + def __init__( + self, + in_channels: int = 3, + out_channels: int = 1, + upsampling: int = 2, + activation: Optional[CappedGELUConfig] = None, + enable_nhwc: bool = False, + enable_healpixpad: bool = False, + ): + """ + Args: + in_channels: The number of input channels. + out_channels: The number of output channels. + upsampling: Size used for upsampling. + activation: ModuleConfig for the activation function used in upsampling. + enable_nhwc: Enable nhwc format, passed to wrapper. + enable_healpixpad: If HEALPixPadding should be enabled, passed to wrapper. + """ + super().__init__() + upsampler = [] + # Upsample transpose conv + upsampler.append( + HEALPixLayer( + layer=nn.ConvTranspose2d, + in_channels=in_channels, + out_channels=out_channels, + kernel_size=upsampling, + stride=upsampling, + padding=0, + enable_nhwc=enable_nhwc, + enable_healpixpad=enable_healpixpad, + ) + ) + if activation is not None: + upsampler.append(activation.build()) + self.upsampler = nn.Sequential(*upsampler) + + def forward(self, x): + """Forward pass of the TransposedConvUpsample layer. + + Args: + x: The values to upsample. + + Returns: + th.Tensor: The upsampled values. + """ + return self.upsampler(x) + + +# Helpers + + +class Interpolate(nn.Module): + """Helper class for interpolation. + + This class handles interpolation, storing scale factor and mode for + `nn.functional.interpolate`. + """ + + def __init__(self, scale_factor: Union[int, Tuple], mode: str = "nearest"): + """ + Args: + scale_factor: Multiplier for spatial size, passed to `nn.functional.interpolate`. + mode, : Interpolation mode used for upsampling, passed to `nn.functional.interpolate`. + """ + super().__init__() + self.interp = nn.functional.interpolate + self.scale_factor = scale_factor + self.mode = mode + + def forward(self, inputs): + """Forward pass of the Interpolate layer. + + Args: + inputs: Inputs to interpolate. + + Returns: + th.Tensor: The interpolated values. + """ + return self.interp(inputs, scale_factor=self.scale_factor, mode=self.mode) diff --git a/fme/fme/ace/models/healpix/healpix_decoder.py b/fme/fme/ace/models/healpix/healpix_decoder.py new file mode 100644 index 0000000..5dd5f15 --- /dev/null +++ b/fme/fme/ace/models/healpix/healpix_decoder.py @@ -0,0 +1,187 @@ +# flake8: noqa +# Copied from https://github.com/NVIDIA/modulus/commit/89a6091bd21edce7be4e0539cbd91507004faf08 +# Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import dataclasses +from typing import List, Optional, Sequence + +import torch as th +import torch.nn as nn + +from .healpix_blocks import ConvBlockConfig, RecurrentBlockConfig + + +@dataclasses.dataclass +class UNetDecoderConfig: + """ + Configuration for the UNet Decoder. + + Parameters: + conv_block: Configuration for the convolutional block. + up_sampling_block: Configuration for the up-sampling block. + output_layer: Configuration for the output layer block. + recurrent_block: Configuration for the recurrent block, by default None. + n_channels: Number of channels for each layer, by default (34, 68, 136). + n_layers: Number of layers in each block, by default (1, 2, 2). + output_channels: Number of output channels, by default 1. + dilations: List of dilation rates for the layers, by default None. + enable_nhwc: Flag to enable NHWC data format, by default False. + enable_healpixpad: Flag to enable HEALPix padding, by default False. + """ + + conv_block: ConvBlockConfig + up_sampling_block: ConvBlockConfig + output_layer: ConvBlockConfig + recurrent_block: Optional[RecurrentBlockConfig] = None + n_channels: List[int] = dataclasses.field(default_factory=lambda: [34, 68, 136]) + n_layers: List[int] = dataclasses.field(default_factory=lambda: [1, 2, 2]) + output_channels: int = 1 + dilations: Optional[list] = None + enable_nhwc: bool = False + enable_healpixpad: bool = False + + def build(self) -> nn.Module: + """ + Builds the UNet Decoder model. + + Returns: + UNet Decoder model. + """ + return UNetDecoder( + conv_block=self.conv_block, + up_sampling_block=self.up_sampling_block, + output_layer=self.output_layer, + recurrent_block=self.recurrent_block, + n_channels=self.n_channels, + n_layers=self.n_layers, + output_channels=self.output_channels, + dilations=self.dilations, + enable_nhwc=self.enable_nhwc, + enable_healpixpad=self.enable_healpixpad, + ) + + +class UNetDecoder(nn.Module): + """Generic UNetDecoder that can be applied to arbitrary meshes.""" + + def __init__( + self, + conv_block: ConvBlockConfig, + up_sampling_block: ConvBlockConfig, + output_layer: ConvBlockConfig, + recurrent_block: Optional[RecurrentBlockConfig] = None, + n_channels: Sequence = (64, 32, 16), + n_layers: Sequence = (1, 2, 2), + output_channels: int = 1, + dilations: Optional[list] = None, + enable_nhwc: bool = False, + enable_healpixpad: bool = False, + ): + """ + Initialize the UNetDecoder. + + Args: + conv_block: Configuration for the convolutional block. + up_sampling_block: Configuration for the upsampling block. + output_layer: Configuration for the output layer. + recurrent_block: Configuration for the recurrent block. If None, recurrent blocks are not used. + n_channels: Sequence specifying the number of channels in each decoder layer. + n_layers: Sequence specifying the number of layers in each block. + output_channels: Number of output channels. + dilations: List of dilations to use for the convolutional blocks. + enable_nhwc: If True, use channel last format. + enable_healpixpad: If True, use the healpixpad library if installed. 
+ """ + super().__init__() + self.channel_dim = 1 + + if dilations is None: + dilations = [1 for _ in range(len(n_channels))] + + self.decoder = [] + for n, curr_channel in enumerate(n_channels): + up_sample_module = None + if n != 0: + up_sampling_block.in_channels = curr_channel + up_sampling_block.out_channels = curr_channel + up_sampling_block.enable_nhwc = enable_nhwc + up_sampling_block.enable_healpixpad = enable_healpixpad + up_sample_module = up_sampling_block.build() + + next_channel = ( + n_channels[n + 1] if n < len(n_channels) - 1 else n_channels[-1] + ) + + conv_block.in_channels = curr_channel * 2 if n > 0 else curr_channel + conv_block.latent_channels = curr_channel + conv_block.out_channels = next_channel + conv_block.dilation = dilations[n] + conv_block.n_layers = n_layers[n] + conv_block.enable_nhwc = enable_nhwc + conv_block.enable_healpixpad = enable_healpixpad + conv_module = conv_block.build() + + rec_module = None + if recurrent_block is not None: + recurrent_block.in_channels = next_channel + recurrent_block.enable_healpixpad = enable_healpixpad + rec_module = recurrent_block.build() + + self.decoder.append( + nn.ModuleDict( + { + "upsamp": up_sample_module, + "conv": conv_module, + "recurrent": rec_module, + } + ) + ) + + self.decoder = nn.ModuleList(self.decoder) + + output_layer.in_channels = curr_channel + output_layer.out_channels = output_channels + output_layer.dilation = dilations[-1] + output_layer.enable_nhwc = enable_nhwc + output_layer.enable_healpixpad = enable_healpixpad + + self.output_layer = output_layer.build() + + def forward(self, inputs): + """ + Forward pass of the UNetDecoder. + + Args: + inputs: The inputs to the forward pass. + + Returns: + The decoded values. + """ + x = inputs[-1] + for n, layer in enumerate(self.decoder): + if layer["upsamp"] is not None: + up = layer["upsamp"](x) + x = th.cat([up, inputs[-1 - n]], dim=self.channel_dim) + x = layer["conv"](x) + if layer["recurrent"] is not None: + x = layer["recurrent"](x) + return self.output_layer(x) + + def reset(self): + """Resets the state of the decoder layers.""" + for layer in self.decoder: + if layer["recurrent"] is not None: + layer["recurrent"].reset() diff --git a/fme/fme/ace/models/healpix/healpix_encoder.py b/fme/fme/ace/models/healpix/healpix_encoder.py new file mode 100644 index 0000000..8471abb --- /dev/null +++ b/fme/fme/ace/models/healpix/healpix_encoder.py @@ -0,0 +1,149 @@ +# flake8: noqa +# Copied from https://github.com/NVIDIA/modulus/commit/89a6091bd21edce7be4e0539cbd91507004faf08 +# Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import dataclasses +from typing import List, Optional, Sequence + +import torch.nn as nn + +from fme.ace.models.healpix.healpix_activations import DownsamplingBlockConfig + +from .healpix_blocks import ConvBlockConfig + + +@dataclasses.dataclass +class UNetEncoderConfig: + """ + Configuration for the UNet Encoder. 
+ + Parameters: + conv_block: Configuration for the convolutional block. + down_sampling_block: Configuration for the down-sampling block. + input_channels: Number of input channels, by default 3. + n_channels: Number of channels for each layer, by default (136, 68, 34). + n_layers: Number of layers in each block, by default (2, 2, 1). + dilations: List of dilation rates for the layers, by default None. + enable_nhwc: Flag to enable NHWC data format, by default False. + enable_healpixpad: Flag to enable HEALPix padding, by default False. + """ + + conv_block: ConvBlockConfig + down_sampling_block: DownsamplingBlockConfig + input_channels: int = 3 + n_channels: List[int] = dataclasses.field(default_factory=lambda: [136, 68, 34]) + n_layers: List[int] = dataclasses.field(default_factory=lambda: [2, 2, 1]) + dilations: Optional[list] = None + enable_nhwc: bool = False + enable_healpixpad: bool = False + + def build(self) -> nn.Module: + """ + Builds the UNet Encoder model. + + Returns: + UNet Encoder model. + """ + return UNetEncoder( + conv_block=self.conv_block, + down_sampling_block=self.down_sampling_block, + input_channels=self.input_channels, + n_channels=self.n_channels, + n_layers=self.n_layers, + dilations=self.dilations, + enable_nhwc=self.enable_nhwc, + enable_healpixpad=self.enable_healpixpad, + ) + + +class UNetEncoder(nn.Module): + """Generic UNetEncoder that can be applied to arbitrary meshes.""" + + def __init__( + self, + conv_block: ConvBlockConfig, + down_sampling_block: DownsamplingBlockConfig, + input_channels: int = 3, + n_channels: Sequence = (16, 32, 64), + n_layers: Sequence = (2, 2, 1), + dilations: Optional[list] = None, + enable_nhwc: bool = False, + enable_healpixpad: bool = False, + ): + """ + Args: + conv_block: config for the convolutional block + down_sampling_block: DownsamplingBlockConfig for the downsample block + input_channels: # of input channels + n_channels: # of channels in each encoder layer + n_layers:, # of layers to use for the convolutional blocks + dilations: list of dilations to use for the the convolutional blocks + enable_nhwc: if channel last format should be used + enable_healpixpad: if healpixpad library should be used (true if installed) + """ + super().__init__() + self.n_channels = n_channels + + if dilations is None: + # Defaults to [1, 1, 1...] in accordance with the number of unet levels + dilations = [1 for _ in range(len(n_channels))] + + # Build encoder + old_channels = input_channels + self.encoder = [] + for n, curr_channel in enumerate(n_channels): + modules = list() + if n > 0: + down_sampling_block.enable_nhwc = enable_nhwc + down_sampling_block.enable_healpixpad = enable_healpixpad + modules.append( + down_sampling_block.build() # Shapes are not used in these calls. + ) + + # Set up conv block + conv_block.in_channels = old_channels + conv_block.latent_channels = curr_channel + conv_block.out_channels = curr_channel + conv_block.dilation = dilations[n] + conv_block.n_layers = n_layers[n] + conv_block.enable_nhwc = enable_nhwc + conv_block.enable_healpixpad = enable_healpixpad + modules.append(conv_block.build()) # Shapes are not used in these calls. 
+ old_channels = curr_channel + + self.encoder.append(nn.Sequential(*modules)) + + self.encoder = nn.ModuleList(self.encoder) + + def forward(self, inputs: Sequence) -> Sequence: + """ + Forward pass of the HEALPix Unet encoder + + Args: + inputs: The inputs to enccode + + Returns: + The encoded values + """ + outputs = [] + for layer in self.encoder: + outputs.append(layer(inputs)) + inputs = outputs[-1] + return outputs + + def reset(self): + """Resets the state of the decoder layers""" + pass diff --git a/fme/fme/ace/models/healpix/healpix_layers.py b/fme/fme/ace/models/healpix/healpix_layers.py new file mode 100644 index 0000000..edbeb39 --- /dev/null +++ b/fme/fme/ace/models/healpix/healpix_layers.py @@ -0,0 +1,493 @@ +# flake8: noqa +# Copied from https://github.com/NVIDIA/modulus/commit/89a6091bd21edce7be4e0539cbd91507004faf08 +# Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This file contains padding and convolution classes to perform according operations on the twelve faces of the HEALPix. + + + HEALPix Face order 3D array representation + ----------------- +-------------------------- //\\ //\\ //\\ //\\ | | | | | +|| 0 | 1 | 2 | 3 || // \\// \\// \\// \\ |0 |1 |2 |3 | +|\\ //\\ //\\ //\\ //| /\\0 //\\1 //\\2 //\\3 // ----------------- +| \\// \\// \\// \\// | // \\// \\// \\// \\// | | | | | +|4//\\5 //\\6 //\\7 //\\4| \\4//\\5 //\\6 //\\7 //\\ |4 |5 |6 |7 | +|// \\// \\// \\// \\| \\/ \\// \\// \\// \\ ----------------- +|| 8 | 9 | 10 | 11 | \\8 //\\9 //\\10//\\11// | | | | | +-------------------------- \\// \\// \\// \\// |8 |9 |10 |11 | + ----------------- + "\\" are top and bottom, whereas + "//" are left and right borders + + +Details on the HEALPix can be found at https://iopscience.iop.org/article/10.1086/427976 + +""" + +import logging +import sys + +import torch as th +import torch.nn as nn + + +class HEALPixFoldFaces(nn.Module): + """Class that folds the faces of a HealPIX tensor""" + + def __init__(self, enable_nhwc: bool = False): + """ + Args: + enable_nhwc: Use nhwc format instead of nchw format + """ + super().__init__() + self.enable_nhwc = enable_nhwc + + def forward(self, tensor: th.Tensor) -> th.Tensor: + """ + Forward pass that folds a HEALPix tensor + [B, F, C, H, W] -> [B*F, C, H, W] + + Args: + tensor: The tensor to fold + + Returns: + th.Tensor: the folded tensor + + """ + N, F, C, H, W = tensor.shape + tensor = th.reshape(tensor, shape=(N * F, C, H, W)) + + if self.enable_nhwc: + tensor = tensor.to(memory_format=th.channels_last) + + return tensor + + +class HEALPixUnfoldFaces(nn.Module): + """Class that unfolds the faces of a HealPIX tensor""" + + def __init__(self, num_faces=12, enable_nhwc=False): + """ + Args: + num_faces: The number of faces on the grid, default 12 + enable_nhwc: If nhwc format is being used, default False + """ + super().__init__() + self.num_faces = num_faces + self.enable_nhwc = enable_nhwc + + def forward(self, tensor: th.Tensor) -> 
th.Tensor: + """ + Forward pass that unfolds a HEALPix tensor + [B*F, C, H, W] -> [B, F, C, H, W] + + Args: + tensor: the tensor to unfold + + Returns: + The unfolded tensor + + """ + NF, C, H, W = tensor.shape + N = int(NF / self.num_faces) + tensor = th.reshape(tensor, shape=(N, self.num_faces, C, H, W)) + + return tensor + + +class HEALPixPadding(nn.Module): + """ + Padding layer for data on a HEALPix sphere. The requirements for using this layer are as follows: + - The last three dimensions are (face=12, height, width) + - The first four indices in the faces dimension [0, 1, 2, 3] are the faces on the northern hemisphere + - The second four indices in the faces dimension [4, 5, 6, 7] are the faces on the equator + - The last four indices in the faces dimension [8, 9, 10, 11] are the faces on the southern hemisphere + + Orientation and arrangement of the HEALPix faces are outlined above. + """ + + def __init__(self, padding: int, enable_nhwc: bool = False): + """ + Args: + padding: The padding size + enable_nhwc: If nhwc format is being used, default False + """ + super().__init__() + self.p = padding + self.d = [-2, -1] + self.enable_nhwc = enable_nhwc + if not isinstance(padding, int) or padding < 1: + raise ValueError( + f"invalid value for 'padding', expected int > 0 but got {padding}" + ) + + self.fold = HEALPixFoldFaces(enable_nhwc=self.enable_nhwc) + self.unfold = HEALPixUnfoldFaces(num_faces=12, enable_nhwc=self.enable_nhwc) + + def forward(self, data: th.Tensor) -> th.Tensor: + """ + Pad each face consistently with its according neighbors in the HEALPix (see ordering and neighborhoods above). + Assumes the Tensor is folded + + Args: + data: The input tensor of shape [..., F, H, W] where each face is to be padded in its HPX context + + Returns: + The padded tensor where each face's height and width are increased by 2*p + """ + th.cuda.nvtx.range_push("HEALPixPadding:forward") + + # unfold faces from batch dim + data = self.unfold(data) + + # Extract the twelve faces (as views of the original tensors) + f00, f01, f02, f03, f04, f05, f06, f07, f08, f09, f10, f11 = [ + th.squeeze(x, dim=1) + for x in th.split(tensor=data, split_size_or_sections=1, dim=1) + ] + + # Assemble the four padded faces on the northern hemisphere + p00 = self.pn( + c=f00, t=f01, tl=f02, lft=f03, bl=f03, b=f04, br=f08, rgt=f05, tr=f01 + ) + p01 = self.pn( + c=f01, t=f02, tl=f03, lft=f00, bl=f00, b=f05, br=f09, rgt=f06, tr=f02 + ) + p02 = self.pn( + c=f02, t=f03, tl=f00, lft=f01, bl=f01, b=f06, br=f10, rgt=f07, tr=f03 + ) + p03 = self.pn( + c=f03, t=f00, tl=f01, lft=f02, bl=f02, b=f07, br=f11, rgt=f04, tr=f00 + ) + + # Assemble the four padded faces on the equator + p04 = self.pe( + c=f04, + t=f00, + tl=self.tl(f00, f03), + lft=f03, + bl=f07, + b=f11, + br=self.br(f11, f08), + rgt=f08, + tr=f05, + ) + p05 = self.pe( + c=f05, + t=f01, + tl=self.tl(f01, f00), + lft=f00, + bl=f04, + b=f08, + br=self.br(f08, f09), + rgt=f09, + tr=f06, + ) + p06 = self.pe( + c=f06, + t=f02, + tl=self.tl(f02, f01), + lft=f01, + bl=f05, + b=f09, + br=self.br(f09, f10), + rgt=f10, + tr=f07, + ) + p07 = self.pe( + c=f07, + t=f03, + tl=self.tl(f03, f02), + lft=f02, + bl=f06, + b=f10, + br=self.br(f10, f11), + rgt=f11, + tr=f04, + ) + + # Assemble the four padded faces on the southern hemisphere + p08 = self.ps( + c=f08, t=f05, tl=f00, lft=f04, bl=f11, b=f11, br=f10, rgt=f09, tr=f09 + ) + p09 = self.ps( + c=f09, t=f06, tl=f01, lft=f05, bl=f08, b=f08, br=f11, rgt=f10, tr=f10 + ) + p10 = self.ps( + c=f10, t=f07, tl=f02, lft=f06, bl=f09, 
b=f09, br=f08, rgt=f11, tr=f11 + ) + p11 = self.ps( + c=f11, t=f04, tl=f03, lft=f07, bl=f10, b=f10, br=f09, rgt=f08, tr=f08 + ) + + res = th.stack( + (p00, p01, p02, p03, p04, p05, p06, p07, p08, p09, p10, p11), dim=1 + ) + + # fold faces into batch dim + res = self.fold(res) + + th.cuda.nvtx.range_pop() + + return res + + def pn( + self, + c: th.Tensor, + t: th.Tensor, + tl: th.Tensor, + lft: th.Tensor, + bl: th.Tensor, + b: th.Tensor, + br: th.Tensor, + rgt: th.Tensor, + tr: th.Tensor, + ) -> th.Tensor: + """ + Applies padding to a northern hemisphere face c under consideration of its given neighbors. + + Args: + c: The central face and tensor that is subject for padding + t: The top neighboring face tensor + tl: The top left neighboring face tensor + lft: The left neighboring face tensor + bl: The bottom left neighboring face tensor + b: The bottom neighboring face tensor + br: The bottom right neighboring face tensor + rgt: The right neighboring face tensor + tr: The top right neighboring face tensor + + Returns: + The padded tensor p + """ + p = self.p # Padding size + d = self.d # Dimensions for rotations + + # Start with top and bottom to extend the height of the c tensor + c = th.cat((t.rot90(1, d)[..., -p:, :], c, b[..., :p, :]), dim=-2) + + # Construct the left and right pads including the corner faces + left = th.cat( + ( + tl.rot90(2, d)[..., -p:, -p:], + lft.rot90(-1, d)[..., -p:], + bl[..., :p, -p:], + ), + dim=-2, + ) + right = th.cat((tr[..., -p:, :p], rgt[..., :p], br[..., :p, :p]), dim=-2) + + return th.cat((left, c, right), dim=-1) + + def pe( + self, + c: th.Tensor, + t: th.Tensor, + tl: th.Tensor, + lft: th.Tensor, + bl: th.Tensor, + b: th.Tensor, + br: th.Tensor, + rgt: th.Tensor, + tr: th.Tensor, + ) -> th.Tensor: + """ + Applies padding to an equatorial face c under consideration of its given neighbors. + + Args: + c: The central face and tensor that is subject for padding + t: The top neighboring face tensor + tl: The top left neighboring face tensor + lft: The left neighboring face tensor + bl: The bottom left neighboring face tensor + b: The bottom neighboring face tensor + br: The bottom right neighboring face tensor + rgt: The right neighboring face tensor + tr: The top right neighboring face tensor + + Returns + ------- + th.Tensor: + The padded tensor p + """ + p = self.p # Padding size + + # Start with top and bottom to extend the height of the c tensor + c = th.cat((t[..., -p:, :], c, b[..., :p, :]), dim=-2) + + # Construct the left and right pads including the corner faces + left = th.cat((tl[..., -p:, -p:], lft[..., -p:], bl[..., :p, -p:]), dim=-2) + right = th.cat((tr[..., -p:, :p], rgt[..., :p], br[..., :p, :p]), dim=-2) + + return th.cat((left, c, right), dim=-1) + + def ps( + self, + c: th.Tensor, + t: th.Tensor, + tl: th.Tensor, + lft: th.Tensor, + bl: th.Tensor, + b: th.Tensor, + br: th.Tensor, + rgt: th.Tensor, + tr: th.Tensor, + ) -> th.Tensor: + """ + Applies padding to a southern hemisphere face c under consideration of its given neighbors. 
+ + Args: + c: The central face and tensor that is subject for padding + t: The top neighboring face tensor + tl: The top left neighboring face tensor + lft: The left neighboring face tensor + bl: The bottom left neighboring face tensor + b: The bottom neighboring face tensor + br: The bottom right neighboring face tensor + rgt: The right neighboring face tensor + tr: The top right neighboring face tensor + + Returns: + The padded tensor p + """ + p = self.p # Padding size + d = self.d # Dimensions for rotations + + # Start with top and bottom to extend the height of the c tensor + c = th.cat((t[..., -p:, :], c, b.rot90(1, d)[..., :p, :]), dim=-2) + + # Construct the left and right pads including the corner faces + left = th.cat((tl[..., -p:, -p:], lft[..., -p:], bl[..., :p, -p:]), dim=-2) + right = th.cat( + (tr[..., -p:, :p], rgt.rot90(-1, d)[..., :p], br.rot90(2, d)[..., :p, :p]), + dim=-2, + ) + + return th.cat((left, c, right), dim=-1) + + def tl(self, top: th.Tensor, lft: th.Tensor) -> th.Tensor: + """ + Assembles the top left corner of a center face in the cases where no according top left face is defined on the + HPX. + + Args: + top: The face above the center face + lft: The face left of the center face + + Returns: + The assembled top left corner (only the sub-part that is required for padding) + """ + ret = th.zeros_like(top)[..., : self.p, : self.p] # super ugly but super fast + + # Bottom left point + ret[..., -1, -1] = 0.5 * top[..., -1, 0] + 0.5 * lft[..., 0, -1] + + # Remaining points + for i in range(1, self.p): + ret[..., -i - 1, -i:] = top[ + ..., -i - 1, :i + ] # Filling top right above main diagonal + ret[..., -i:, -i - 1] = lft[ + ..., :i, -i - 1 + ] # Filling bottom left below main diagonal + ret[..., -i - 1, -i - 1] = ( + 0.5 * top[..., -i - 1, 0] + 0.5 * lft[..., 0, -i - 1] + ) # Diagonal + + return ret + + def br(self, b: th.Tensor, r: th.Tensor) -> th.Tensor: + """ + Assembles the bottom right corner of a center face in the cases where no according bottom right face is defined + on the HPX. + + Args: + b: The face below the center face + r: The face right of the center face + + Returns: + The assembled bottom right corner (only the sub-part that is required for padding) + """ + ret = th.zeros_like(b)[..., : self.p, : self.p] + + # Top left point + ret[..., 0, 0] = 0.5 * b[..., 0, -1] + 0.5 * r[..., -1, 0] + + # Remaining points + for i in range(1, self.p): + ret[..., :i, i] = r[..., -i:, i] # Filling top right above main diagonal + ret[..., i, :i] = b[..., i, -i:] # Filling bottom left below main diagonal + ret[..., i, i] = 0.5 * b[..., i, -1] + 0.5 * r[..., -1, i] # Diagonal + + return ret + + +class HEALPixLayer(nn.Module): + """Pytorch module for applying any base torch Module on a HEALPix tensor. Expects all input/output tensors to have a + shape [..., 12, H, W], where 12 is the dimension of the faces. 
+ """ + + def __init__(self, layer, **kwargs): + """ + Args: + layer: Any torch layer function, e.g., nn.Conv2d + kwargs: The arguments that are passed to the torch layer function, e.g., kernel_size + """ + super().__init__() + layers = [] + + if "enable_nhwc" in kwargs: + enable_nhwc = kwargs["enable_nhwc"] + del kwargs["enable_nhwc"] + else: + enable_nhwc = False + + if "enable_healpixpad" in kwargs and kwargs["enable_healpixpad"]: + raise NotImplementedError( + "HEALPixPaddingv2 is not available in this environment" + ) + + if "enable_healpixpad" in kwargs: + del kwargs["enable_healpixpad"] + + if not isinstance(layer, type) or not issubclass(layer, th.nn.Module): + raise TypeError( + f"Expected a subclass of torch.nn.Module, got {type(layer).__name__}" + ) + # Define a HEALPixPadding layer if the given layer is a convolution layer + if layer.__bases__[0] is nn.modules.conv._ConvNd and kwargs["kernel_size"] > 1: + kwargs["padding"] = 0 # Disable native padding + kernel_size = 3 if "kernel_size" not in kwargs else kwargs["kernel_size"] + dilation = 1 if "dilation" not in kwargs else kwargs["dilation"] + padding = ((kernel_size - 1) // 2) * dilation + layers.append(HEALPixPadding(padding=padding, enable_nhwc=enable_nhwc)) + + layers.append(layer(**kwargs)) + self.layers = nn.Sequential(*layers) + + if enable_nhwc: + self.layers = self.layers.to(memory_format=th.channels_last) + + def forward(self, x: th.Tensor) -> th.Tensor: + """ + Performs the forward pass using the defined layer function and the given data. + + :param x: The input tensor of shape [..., F=12, H, W] + :return: The output tensor of this HEALPix layer + """ + res = self.layers(x) + return res diff --git a/fme/fme/ace/models/healpix/healpix_recunet.py b/fme/fme/ace/models/healpix/healpix_recunet.py new file mode 100644 index 0000000..56b8d78 --- /dev/null +++ b/fme/fme/ace/models/healpix/healpix_recunet.py @@ -0,0 +1,436 @@ +# flake8: noqa +# Copied from https://github.com/NVIDIA/modulus/commit/89a6091bd21edce7be4e0539cbd91507004faf08 +# Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Sequence + +import pandas as pd +import torch as th +import torch.nn as nn + +from .healpix_decoder import UNetDecoderConfig +from .healpix_encoder import UNetEncoderConfig +from .healpix_layers import HEALPixFoldFaces, HEALPixUnfoldFaces + + +class HEALPixRecUNet(nn.Module): + """Deep Learning Weather Prediction (DLWP) recurrent UNet model on the HEALPix mesh.""" + + def __init__( + self, + encoder: UNetEncoderConfig, + decoder: UNetDecoderConfig, + input_channels: int, + output_channels: int, + prognostic_variables: int, + n_constants: int, + decoder_input_channels: int, + input_time_size: int, + output_time_size: int, + delta_time: str = "6h", + reset_cycle: str = "24h", + presteps: int = 1, + enable_nhwc: bool = False, + enable_healpixpad: bool = False, + couplings: list = [], + ): + """ + Initialize the HEALPixRecUNet model. 
+
+        Args:
+            encoder: UNetEncoderConfig
+                ModuleConfig of instantiable parameters for the U-net encoder.
+            decoder: UNetDecoderConfig
+                ModuleConfig of instantiable parameters for the U-net decoder.
+            input_channels: int
+                Number of input channels expected in the input array schema. Note that this
+                is the number of input variables in the data, NOT including any data
+                reshaping done for the encoder.
+            output_channels: int
+                Number of output channels expected in the output array schema, or output variables.
+            prognostic_variables: int
+                Number of prognostic variables, i.e. variables that appear in both the inputs
+                and outputs; the first prognostic_variables * input_time_size output channels
+                are treated as residual updates to the corresponding inputs in the forward pass.
+            n_constants: int
+                Number of optional constants expected in the input arrays. If this is zero, no constants
+                should be provided as inputs to forward.
+            decoder_input_channels: int
+                Number of optional prescribed variables expected in the decoder input array
+                for both inputs and outputs. If this is zero, no decoder inputs should be provided as inputs to forward.
+            input_time_size: int
+                Number of time steps in the input array.
+            output_time_size: int
+                Number of time steps in the output array.
+            delta_time: str, optional
+                Hours between two consecutive data points.
+            reset_cycle: str, optional
+                Hours after which the recurrent states are reset to zero and re-initialized. Set to
+                np.infty to never reset the hidden states.
+            presteps: int, optional
+                Number of model steps used to initialize the recurrent states.
+            enable_nhwc: bool, optional
+                Model with [N, H, W, C] instead of [N, C, H, W].
+            enable_healpixpad: bool, optional
+                Enable CUDA HEALPixPadding if installed.
+            couplings: list, optional
+                Sequence of dictionaries that describe coupling mechanisms. Currently unused in our
+                production model, but kept in the module definition in case we bring our SST module
+                (which subclasses it) into the picture.
+        """
+        super().__init__()
+        self.channel_dim = 2  # Now 2 with [B, F, T*C, H, W]. Was 1 in old data format with [B, T*C, F, H, W]
+
+        self.input_channels = input_channels
+
+        if n_constants == 0 and decoder_input_channels == 0:
+            pass
+            # raise NotImplementedError(
+            #     "support for models with no constant fields and no decoder inputs (TOA insolation) is not available at this time."
+            # )
+        if couplings is not None:
+            if len(couplings) > 0:
+                if n_constants == 0:
+                    raise NotImplementedError(
+                        "support for coupled models with no constant fields is not available at this time."
+                    )
+                if decoder_input_channels == 0:
+                    raise NotImplementedError(
+                        "support for coupled models with no decoder inputs (TOA insolation) is not available at this time."
+ ) + else: + couplings = [] + + # add coupled fields to input channels for model initialization + self.coupled_channels = self._compute_coupled_channels(couplings) + self.couplings = couplings + self.train_couplers = None + self.output_channels = output_channels + self.n_constants = n_constants + self.prognostic_variables = prognostic_variables + self.decoder_input_channels = decoder_input_channels + self.input_time_size = input_time_size + self.output_time_size = output_time_size + self.delta_t = int(pd.Timedelta(delta_time).total_seconds() // 3600) + self.reset_cycle = int(pd.Timedelta(reset_cycle).total_seconds() // 3600) + self.presteps = presteps + self.enable_nhwc = enable_nhwc + self.enable_healpixpad = enable_healpixpad + + # Number of passes through the model, or a diagnostic model with only one output time + self.is_diagnostic = self.output_time_size == 1 and self.input_time_size > 1 + if not self.is_diagnostic and ( + self.output_time_size % self.input_time_size != 0 + ): + raise ValueError( + f"'output_time_size' must be a multiple of 'input_time_size' (got " + f"{self.output_time_size} and {self.input_time_size})" + ) + + # Build the model layers + self.fold = HEALPixFoldFaces() + self.unfold = HEALPixUnfoldFaces(num_faces=12) + encoder.input_channels = self._compute_input_channels() + encoder.enable_nhwc = self.enable_nhwc + encoder.enable_healpixpad = self.enable_healpixpad + self.encoder = encoder.build() + + self.encoder_depth = len(self.encoder.n_channels) + decoder.output_channels = self._compute_output_channels() + decoder.enable_nhwc = self.enable_nhwc + decoder.enable_healpixpad = self.enable_healpixpad + self.decoder = decoder.build() + + @property + def integration_steps(self): + """Number of integration steps""" + return max(self.output_time_size // self.input_time_size, 1) + + def _compute_input_channels(self) -> int: + """ + Calculate total number of input channels in the model. + + Returns: + int: The total number of input channels. + """ + return ( + self.input_time_size * (self.input_channels + self.decoder_input_channels) + + self.n_constants + + self.coupled_channels + ) + + def _compute_coupled_channels(self, couplings): + """ + Get the number of coupled channels. + + Args: + couplings: list + Sequence of dictionaries that describe coupling mechanisms. + + Returns: + int: The number of coupled channels. + """ + c_channels = 0 + for c in couplings: + c_channels += len(c["params"]["variables"]) * len( + c["params"]["input_times"] + ) + return c_channels + + def _compute_output_channels(self) -> int: + """ + Compute the total number of output channels in the model. + + Returns: + int: The total number of output channels. + """ + return ( + 1 if self.is_diagnostic else self.input_time_size + ) * self.output_channels + + # def _reshape_inputs(self, inputs: Sequence, step: int = 0) -> th.Tensor: + # """ + # Returns a single tensor to pass into the model encoder/decoder. Squashes the time/channel dimension and + # concatenates in constants and decoder inputs. 
+ + # Args: + # inputs: list of expected input tensors (inputs, decoder_inputs, constants) + # step: step number in the sequence of integration_steps + + # Returns: + # reshaped Tensor in expected shape for model encoder + # """ + + # # if len(self.couplings) > 0: + # # result = [ + # # inputs[0].flatten( + # # start_dim=self.channel_dim, end_dim=self.channel_dim + 1 + # # ), + # # inputs[1][ + # # :, + # # :, + # # slice(step * self.input_time_size, (step + 1) * self.input_time_size), + # # ..., + # # ].flatten( + # # start_dim=self.channel_dim, end_dim=self.channel_dim + 1 + # # ), # DI + # # inputs[2].expand( + # # *tuple([inputs[0].shape[0]] + len(inputs[2].shape) * [-1]) + # # ), # constants + # # inputs[3].permute(0, 2, 1, 3, 4), # coupled inputs + # # ] + # # res = th.cat(result, dim=self.channel_dim) + + # # else: + # # if self.n_constants == 0: + # # result = [ # This logic changes for no insolation layer for the time being + # # inputs[0].flatten( + # # start_dim=self.channel_dim, end_dim=self.channel_dim + 1 + # # ), + # # # inputs[0].flatten( + # # # start_dim=self.channel_dim, end_dim=self.channel_dim + 1 + # # # ), + # # # inputs[1][ + # # # :, + # # # :, + # # # slice( + # # # step * self.input_time_size, (step + 1) * self.input_time_size + # # # ), + # # # ..., + # # # ].flatten( + # # # start_dim=self.channel_dim, end_dim=self.channel_dim + 1 + # # # ), # DI + # # ] + # # res = th.cat(result, dim=self.channel_dim) + + # # # fold faces into batch dim + # # res = self.fold(res) + + # # return res + + # # if self.decoder_input_channels == 0: + # # result = [ + # # inputs[0].flatten( + # # start_dim=self.channel_dim, end_dim=self.channel_dim + 1 + # # ), + # # inputs[1].expand( + # # *tuple([inputs[0].shape[0]] + len(inputs[1].shape) * [-1]) + # # ), # inputs + # # ] + # # print(f"result 1 is {result[0].shape}") + # # print(f"result 2 is {result[1].shape}") + + # # res = th.cat(result, dim=self.channel_dim) + + # # # fold faces into batch dim + # # res = self.fold(res) + + # # return res + + # # result = [ + # # inputs[0].flatten( + # # start_dim=self.channel_dim, end_dim=self.channel_dim + 1 + # # ), + # # inputs[1][ + # # :, + # # :, + # # slice(step * self.input_time_size, (step + 1) * self.input_time_size), + # # ..., + # # ].flatten( + # # start_dim=self.channel_dim, end_dim=self.channel_dim + 1 + # # ), # DI + # # inputs[2].expand( + # # *tuple([inputs[0].shape[0]] + len(inputs[2].shape) * [-1]) + # # ), # constants + # # ] + # # res = th.cat(result, dim=self.channel_dim) + + # # fold faces into batch dim + # # res = self.fold(res) + # # res = self.fold(inputs) + # # return res + + # def _reshape_outputs(self, outputs: th.Tensor) -> th.Tensor: + # """Returns a multiple tensors to from the model decoder. + # Splits the time/channel dimensions. 
+ + # Args: + # inputs: list of expected input tensors (inputs, decoder_inputs, constants) + # step: step number in the sequence of integration_steps + + # Returns: + # reshaped Tensor in expected shape for model outputs + # """ + # # unfold: + # outputs = self.unfold(outputs) + + # return outputs + + def _initialize_hidden( + self, inputs: Sequence, outputs: Sequence, step: int + ) -> None: + """Initialize the hidden layers + + Args: + inputs: Inputs to use to initialize the hideen layers + outputs: Outputs to use to initialize the hideen layers + step: Current step number of the initialization + """ + self.reset() + for prestep in range(self.presteps): + if step < self.presteps: + s = step + prestep + if len(self.couplings) > 0: + input_tensor = self.fold( + inputs=[ + inputs[0][ + :, + :, + s * self.input_time_size : (s + 1) + * self.input_time_size, + ] + ] + + list(inputs[1:3]) + + [inputs[3][prestep]], + step=step + prestep, + ) + else: + input_tensor = self.fold( + inputs=[ + inputs[0][ + :, + :, + s * self.input_time_size : (s + 1) + * self.input_time_size, + ] + ] + + list(inputs[1:]), + step=step + prestep, + ) + else: + s = step - self.presteps + prestep + if len(self.couplings) > 0: + input_tensor = self.fold( + inputs=[outputs[s - 1]] + + list(inputs[1:3]) + + [inputs[3][step - (prestep - self.presteps)]], + step=s + 1, + ) + else: + input_tensor = self.fold( + inputs=[outputs[s - 1]] + list(inputs[1:]), step=s + 1 + ) + self.decoder(self.encoder(input_tensor)) + + def forward(self, inputs: th.Tensor, output_only_last=False) -> th.Tensor: + """ + Forward pass of the HEALPixUnet + + Args: + inputs: Inputs to the model, which is currently in the form [B, F, C, H, W]. + (We assume that constants have been preprocessed to persist across batch) + (We also assume for now that the time is implied to be 1) + + Originally, this was expected to be a sequence of the form [prognostics|TISR|constants]: + [B, F, T, C, H, W] the format for prognostics and TISR + [F, C, H, W] the format for constants + + output_only_last: If only the last dimension of the outputs should + be returned + + Returns: + Predicted outputs + """ + self.reset() # will reset every step for now + # We need to make sure that the new input is the correct size, since we no longer have the ability + # to differentiate between the inputs, decoder inputs, and constants + if self._compute_input_channels() != inputs.shape[2]: + raise ValueError( + f"Expected input should have channels {self._compute_input_channels()}," + f" got {inputs.shape[2]}." + ) + + # The input logic gets really changed now that we just have a prognostic vars channel in the inputs. + # Basically we assume that all the batch-wise expansion for the inputs has already been done. + + # (Re-)initialize recurrent hidden states + # if (step * (self.delta_t * self.input_time_size)) % self.reset_cycle == 0: + # self._initialize_hidden(inputs=inputs, outputs=outputs, step=step) + # Skipping this for now. We assume a single input time step, and resetting every step. + # s = self.presteps + input_tensor = self.fold( + inputs + ) # Padding happens in HEALPixPadding, which will + # unfold this tensor to handle it. 
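+        # Shape note (indicative): HEALPixFoldFaces collapses the face dimension into
+        # the batch dimension (roughly [B, F, C, H, W] -> [B * F, C, H, W]) so the
+        # encoder and decoder can treat each face as an independent sample; the
+        # predictions are unfolded back into per-face tensors before being returned.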
+ + encodings = self.encoder(input_tensor) + decodings = self.decoder(encodings) + + # Residual prediction + n_prognostic_channels = self.prognostic_variables * self.input_time_size + prognostic_outputs = ( + input_tensor[:, :n_prognostic_channels] + + decodings[:, :n_prognostic_channels] + ) + + outputs_only = decodings[:, n_prognostic_channels:] + + reshaped = th.cat( + [self.unfold(prognostic_outputs), self.unfold(outputs_only)], + dim=self.channel_dim, + ) + + return reshaped + + def reset(self): + """Resets the state of the network""" + self.encoder.reset() + self.decoder.reset() diff --git a/fme/fme/ace/models/makani/sfnonet.py b/fme/fme/ace/models/makani/sfnonet.py index 4e33ec9..18367f6 100644 --- a/fme/fme/ace/models/makani/sfnonet.py +++ b/fme/fme/ace/models/makani/sfnonet.py @@ -17,6 +17,7 @@ import math from functools import partial +from typing import Tuple # for annotation of models import torch @@ -249,8 +250,8 @@ def __init__( sht_grid_type="legendre-gauss", filter_type="linear", operator_type="dhconv", - inp_shape=(721, 1440), - out_shape=(721, 1440), + inp_shape: Tuple[int, int] = (721, 1440), + out_shape: Tuple[int, int] = (721, 1440), scale_factor=8, inp_chans=2, out_chans=2, @@ -424,9 +425,7 @@ def __init__( torch.zeros(1, embed_dim, self.inp_shape_loc[0], self.inp_shape_loc[1]) ) # information about how tensors are shared / sharded across ranks - self.pos_embed.is_shared_mp = ( - [] - ) # no reduction required since pos_embed is already serial + self.pos_embed.is_shared_mp = [] # no reduction required since pos_embed is already serial self.pos_embed.sharded_dims_mp = [None, None, "h", "w"] self.pos_embed.type = "direct" with torch.no_grad(): @@ -500,10 +499,10 @@ def _init_spectral_transforms( ifft_handle = InverseRealFFT2 self.trans_down = fft_handle( - *self.inp_shape, lmax=modes_lat, mmax=modes_lon + self.inp_shape[0], self.inp_shape[1], lmax=modes_lat, mmax=modes_lon ).float() self.itrans_up = ifft_handle( - *self.out_shape, lmax=modes_lat, mmax=modes_lon + self.out_shape[0], self.out_shape[1], lmax=modes_lat, mmax=modes_lon ).float() self.trans = fft_handle( self.h, self.w, lmax=modes_lat, mmax=modes_lon diff --git a/fme/fme/ace/models/modulus/layers.py b/fme/fme/ace/models/modulus/layers.py index 22eddfe..81ec16b 100644 --- a/fme/fme/ace/models/modulus/layers.py +++ b/fme/fme/ace/models/modulus/layers.py @@ -22,10 +22,9 @@ import torch.nn.functional as F from torch.cuda import amp from torch.utils.checkpoint import checkpoint -from torch_harmonics import * -from .activations import * -from .contractions import * +from .activations import ComplexReLU +from .contractions import compl_mul2d_fwd, compl_muladd2d_fwd @torch.jit.script @@ -199,185 +198,6 @@ def forward(self, x): # pragma: no cover return out -class SpectralConv2d(nn.Module): - """ - Spectral Convolution as utilized in - """ - - def __init__( - self, - forward_transform, - inverse_transform, - hidden_size, - sparsity_threshold=0.0, - hard_thresholding_fraction=1, - use_complex_kernels=False, - compression=None, - rank=0, - bias=False, - ): # pragma: no cover - super(SpectralConv2d, self).__init__() - - self.hidden_size = hidden_size - self.sparsity_threshold = sparsity_threshold - self.hard_thresholding_fraction = hard_thresholding_fraction - self.scale = 1 / hidden_size**2 - self.contract_handle = ( - compl_contract2d_fwd_c if use_complex_kernels else compl_contract2d_fwd - ) - - self.forward_transform = forward_transform - self.inverse_transform = inverse_transform - - self.output_dims = 
(self.inverse_transform.nlat, self.inverse_transform.nlon) - modes_lat = self.inverse_transform.lmax - modes_lon = self.inverse_transform.mmax - self.modes_lat = int(modes_lat * self.hard_thresholding_fraction) - self.modes_lon = int(modes_lon * self.hard_thresholding_fraction) - - # new simple linear layer - self.w = nn.Parameter( - self.scale - * torch.randn( - self.hidden_size, self.hidden_size, self.modes_lat, self.modes_lon, 2 - ) - ) - # optional bias - if bias: - self.b = nn.Parameter( - self.scale * torch.randn(1, self.hidden_size, *self.output_dims) - ) - - def forward(self, x): # pragma: no cover - dtype = x.dtype - # x = x.float() - B, C, H, W = x.shape - - with amp.autocast(enabled=False): - x = x.to(torch.float32) - x = self.forward_transform(x) - x = torch.view_as_real(x) - x = x.to(dtype) - - # do spectral conv - modes = torch.zeros(x.shape, device=x.device) - - # modes[:, :, :self.modes_lat, :self.modes_lon, :] = self.contract_handle(x[:, :, :self.modes_lat, :self.modes_lon, :], self.wh) - # modes[:, :, -self.modes_lat:, :self.modes_lon, :] = self.contract_handle(x[:, :, -self.modes_lat:, :self.modes_lon, :], self.wl) - modes = self.contract_handle(x, self.w) - - # finalize - x = F.softshrink(modes, lambd=self.sparsity_threshold) - x = torch.view_as_complex(x) - - with amp.autocast(enabled=False): - x = x.to(torch.float32) - x = torch.view_as_complex(x) - x = x.contiguous() - x = self.inverse_transform(x) - x = x.to(dtype) - - if hasattr(self, "b"): - x = x + self.b - - return x - - -class SpectralConvS2(nn.Module): - """ - Spectral Convolution as utilized in - """ - - def __init__( - self, - forward_transform, - inverse_transform, - hidden_size, - sparsity_threshold=0.0, - use_complex_kernels=False, - compression=None, - rank=128, - bias=False, - ): # pragma: no cover - super(SpectralConvS2, self).__init__() - - self.hidden_size = hidden_size - self.sparsity_threshold = sparsity_threshold - self.scale = 0.02 - - self.forward_transform = forward_transform - self.inverse_transform = inverse_transform - - self.modes_lat = self.forward_transform.lmax - self.modes_lon = self.forward_transform.mmax - - assert self.inverse_transform.lmax == self.modes_lat - assert self.inverse_transform.mmax == self.modes_lon - - # remember the lower triangular indices - ii, jj = torch.tril_indices(self.modes_lat, self.modes_lon) - self.register_buffer("ii", ii) - self.register_buffer("jj", jj) - - if compression == "tt": - self.rank = rank - # tensortrain coefficients - g1 = nn.Parameter(self.scale * torch.randn(self.hidden_size, self.rank, 2)) - g2 = nn.Parameter( - self.scale * torch.randn(self.rank, self.hidden_size, self.rank, 2) - ) - g3 = nn.Parameter(self.scale * torch.randn(self.rank, len(ii), 2)) - self.w = nn.ParameterList([g1, g2, g3]) - - self.contract_handle = ( - contract_tt # if use_complex_kernels else raise(NotImplementedError) - ) - else: - self.w = nn.Parameter( - self.scale * torch.randn(self.hidden_size, self.hidden_size, len(ii), 2) - ) - self.contract_handle = ( - compl_contract_fwd_c if use_complex_kernels else compl_contract_fwd - ) - - if bias: - self.b = nn.Parameter( - self.scale * torch.randn(1, self.hidden_size, *self.output_dims) - ) - - def forward(self, x): # pragma: no cover - dtype = x.dtype - # x = x.float() - B, C, H, W = x.shape - - with amp.autocast(enabled=False): - x = x.to(torch.float32) - x = x.contiguous() - x = self.forward_transform(x) - x = torch.view_as_real(x) - x = x.to(dtype) - - # do spectral conv - modes = torch.zeros(x.shape, device=x.device) 
- modes[:, :, self.ii, self.jj, :] = self.contract_handle( - x[:, :, self.ii, self.jj, :], self.w - ) - - # finalize - x = F.softshrink(modes, lambd=self.sparsity_threshold) - - with amp.autocast(enabled=False): - x = x.to(torch.float32) - x = torch.view_as_complex(x) - x = self.inverse_transform(x) - x = x.to(dtype) - - if hasattr(self, "b"): - x = x + self.b - - return x - - class SpectralAttention2d(nn.Module): """ 2d Spectral Attention layer @@ -404,10 +224,10 @@ def __init__( self.hidden_size = int(hidden_size_factor * self.embed_dim) self.scale = 0.02 self.spectral_layers = spectral_layers - self.mul_add_handle = ( - compl_muladd2d_fwd_c if use_complex_kernels else compl_muladd2d_fwd - ) - self.mul_handle = compl_mul2d_fwd_c if use_complex_kernels else compl_mul2d_fwd + if use_complex_kernels: + raise NotImplementedError("complex kernels not supported") + self.mul_add_handle = compl_muladd2d_fwd + self.mul_handle = compl_mul2d_fwd self.modes_lat = forward_transform.lmax self.modes_lon = forward_transform.mmax @@ -512,12 +332,12 @@ def __init__( self.sparsity_threshold = sparsity_threshold self.hidden_size = int(hidden_size_factor * self.embed_dim) self.scale = 0.02 + if use_complex_kernels: + raise NotImplementedError("complex kernels not supported") # self.mul_add_handle = compl_muladd1d_fwd_c if use_complex_kernels else compl_muladd1d_fwd - self.mul_add_handle = ( - compl_muladd2d_fwd_c if use_complex_kernels else compl_muladd2d_fwd - ) + self.mul_add_handle = compl_muladd2d_fwd # self.mul_handle = compl_mul1d_fwd_c if use_complex_kernels else compl_mul1d_fwd - self.mul_handle = compl_mul2d_fwd_c if use_complex_kernels else compl_mul2d_fwd + self.mul_handle = compl_mul2d_fwd self.spectral_layers = spectral_layers self.modes_lat = forward_transform.lmax diff --git a/fme/fme/ace/models/modulus/sfnonet.py b/fme/fme/ace/models/modulus/sfnonet.py index 0c0434b..1f8806c 100644 --- a/fme/fme/ace/models/modulus/sfnonet.py +++ b/fme/fme/ace/models/modulus/sfnonet.py @@ -232,9 +232,9 @@ def forward(self, x): x = self.act_layer(x) x_norm = torch.zeros_like(x) - x_norm[ - ..., : self.output_shape_loc[0], : self.output_shape_loc[1] - ] = self.norm1(x[..., : self.output_shape_loc[0], : self.output_shape_loc[1]]) + x_norm[..., : self.output_shape_loc[0], : self.output_shape_loc[1]] = ( + self.norm1(x[..., : self.output_shape_loc[0], : self.output_shape_loc[1]]) + ) x = x_norm if hasattr(self, "mlp"): diff --git a/fme/fme/ace/registry/__init__.py b/fme/fme/ace/registry/__init__.py index 12c781f..31b7898 100644 --- a/fme/fme/ace/registry/__init__.py +++ b/fme/fme/ace/registry/__init__.py @@ -1,6 +1,6 @@ # import modules so they are registered from . import prebuilt as _prebuilt from . 
import sfno as _sfno -from .registry import ModuleSelector, get_from_registry, register +from .registry import ModuleSelector del _prebuilt, _sfno diff --git a/fme/fme/ace/registry/hpx.py b/fme/fme/ace/registry/hpx.py new file mode 100644 index 0000000..e90f42c --- /dev/null +++ b/fme/fme/ace/registry/hpx.py @@ -0,0 +1,78 @@ +import dataclasses +from typing import Tuple + +import torch.nn as nn + +from fme.ace.models.healpix.healpix_decoder import UNetDecoderConfig +from fme.ace.models.healpix.healpix_encoder import UNetEncoderConfig +from fme.ace.models.healpix.healpix_recunet import HEALPixRecUNet +from fme.ace.registry.registry import ModuleConfig, ModuleSelector + + +@ModuleSelector.register("HEALPixRecUNet") +@dataclasses.dataclass +class HEALPixRecUNetBuilder(ModuleConfig): + """ + Configuration for the HEALPixRecUNet architecture used in DLWP. + + Parameters: + presteps: Number of pre-steps, by default 1. + input_time_size: Input time dimension, by default 0. + output_time_size: Output time dimension, by default 0. + delta_time: Delta time interval, by default "6h". + reset_cycle: Reset cycle interval, by default "24h". + input_channels: Number of input channels, by default 8. + output_channels: Number of output channels, by default 8. + n_constants: Number of constant input channels, by default 2. + decoder_input_channels: Number of input channels for the decoder, by default 1. + enable_nhwc: Flag to enable NHWC data format, by default False. + enable_healpixpad: Flag to enable HEALPix padding, by default False. + """ + + encoder: UNetEncoderConfig + decoder: UNetDecoderConfig + presteps: int = 1 + input_time_size: int = 0 + output_time_size: int = 0 + delta_time: str = "6h" + reset_cycle: str = "24h" + n_constants: int = 2 + decoder_input_channels: int = 1 + prognostic_variables: int = 7 + enable_nhwc: bool = False + enable_healpixpad: bool = False + + def build( + self, + n_in_channels: int, + n_out_channels: int, + img_shape: Tuple[int, int], + ) -> nn.Module: + """ + Builds the HEALPixRecUNet model. + + Args: + n_in_channels: Number of input channels. + n_out_channels: Number of output channels. + img_shape: Shape of the input image. + + Returns: + HEALPixRecUNet model. 
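+
+        A minimal call sketch (here ``builder`` stands for an instance of this config;
+        argument values are illustrative and mirror the accompanying registry test):
+
+            module = builder.build(n_in_channels=7, n_out_channels=7, img_shape=(8, 16))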
+ """ + # Construct the HEALPixRecUNet module here using the parameters + return HEALPixRecUNet( + encoder=self.encoder, + decoder=self.decoder, + input_channels=n_in_channels, + output_channels=n_out_channels, + prognostic_variables=self.prognostic_variables, + n_constants=self.n_constants, + decoder_input_channels=self.decoder_input_channels, + input_time_size=self.input_time_size, + output_time_size=self.output_time_size, + delta_time=self.delta_time, + reset_cycle=self.reset_cycle, + presteps=self.presteps, + enable_nhwc=self.enable_nhwc, + enable_healpixpad=self.enable_healpixpad, + ) diff --git a/fme/fme/ace/registry/prebuilt.py b/fme/fme/ace/registry/prebuilt.py index f8bebdc..880df5e 100644 --- a/fme/fme/ace/registry/prebuilt.py +++ b/fme/fme/ace/registry/prebuilt.py @@ -3,10 +3,10 @@ from torch import nn -from fme.ace.registry.registry import ModuleConfig, register +from fme.ace.registry.registry import ModuleConfig, ModuleSelector -@register("prebuilt") +@ModuleSelector.register("prebuilt") @dataclasses.dataclass class PreBuiltBuilder(ModuleConfig): """ diff --git a/fme/fme/ace/registry/registry.py b/fme/fme/ace/registry/registry.py index 2486c8b..3fa9bd1 100644 --- a/fme/fme/ace/registry/registry.py +++ b/fme/fme/ace/registry/registry.py @@ -1,6 +1,4 @@ -from fme.core.registry import ( # noqa: F401 +from fme.core.registry.module import ( # noqa: F401 ModuleConfig, ModuleSelector, - get_from_registry, - register, ) diff --git a/fme/fme/ace/registry/sfno.py b/fme/fme/ace/registry/sfno.py index 0031ea7..f915055 100644 --- a/fme/fme/ace/registry/sfno.py +++ b/fme/fme/ace/registry/sfno.py @@ -5,12 +5,12 @@ SphericalFourierNeuralOperatorNet as MakaniSFNO, ) from fme.ace.models.modulus.sfnonet import SphericalFourierNeuralOperatorNet -from fme.ace.registry.registry import ModuleConfig, register +from fme.ace.registry.registry import ModuleConfig, ModuleSelector # this is based on the call signature of SphericalFourierNeuralOperatorNet at # https://github.com/NVIDIA/modulus/blob/b8e27c5c4ebc409e53adaba9832138743ede2785/modulus/models/sfno/sfnonet.py#L292 # noqa: E501 -@register("SphericalFourierNeuralOperatorNet") +@ModuleSelector.register("SphericalFourierNeuralOperatorNet") @dataclasses.dataclass class SphericalFourierNeuralOperatorBuilder(ModuleConfig): """ @@ -55,7 +55,7 @@ def build( return sfno_net -@register("SFNO-v0.1.0") +@ModuleSelector.register("SFNO-v0.1.0") @dataclasses.dataclass class SFNO_V0_1_0(ModuleConfig): """ diff --git a/fme/fme/ace/registry/test_hpx.py b/fme/fme/ace/registry/test_hpx.py new file mode 100644 index 0000000..9a77522 --- /dev/null +++ b/fme/fme/ace/registry/test_hpx.py @@ -0,0 +1,830 @@ +import dataclasses +import datetime +import logging +from typing import Tuple, Union + +import numpy as np +import pytest +import torch as th + +from fme.ace.models.healpix.healpix_activations import ( + CappedGELUConfig, + DownsamplingBlockConfig, +) +from fme.ace.models.healpix.healpix_blocks import ConvBlockConfig, RecurrentBlockConfig +from fme.ace.models.healpix.healpix_decoder import UNetDecoder +from fme.ace.models.healpix.healpix_encoder import UNetEncoder +from fme.ace.models.healpix.healpix_recunet import HEALPixRecUNet +from fme.ace.registry.hpx import UNetDecoderConfig, UNetEncoderConfig +from fme.ace.stepper import SingleModuleStepperConfig +from fme.core.coordinates import HybridSigmaPressureCoordinate +from fme.core.device import get_device +from fme.core.gridded_ops import LatLonOperations +from fme.core.normalizer import NormalizationConfig + +TIMESTEP 
= datetime.timedelta(hours=6) +logger = logging.getLogger("__name__") + + +def fix_random_seeds(seed=0): + """Fix random seeds for reproducibility""" + np.random.seed(seed) + th.manual_seed(seed) + th.cuda.manual_seed(seed) + + +def conv_next_block_config(in_channels=3, out_channels=1): + activation_block_config = CappedGELUConfig(cap_value=10) + conv_next_block_config = ConvBlockConfig( + in_channels=in_channels, + out_channels=out_channels, + activation=activation_block_config, + kernel_size=3, + dilation=1, + upscale_factor=4, + block_type="ConvNeXtBlock", + ) + return conv_next_block_config + + +def down_sampling_block_config(): + return DownsamplingBlockConfig(pooling=2, block_type="AvgPool") + + +def encoder_config( + conv_next_block_config, down_sampling_block_config, n_channels=[136, 68, 34] +): + return UNetEncoderConfig( + conv_block=conv_next_block_config, + down_sampling_block=down_sampling_block_config, + n_channels=n_channels, + dilations=[1, 2, 4], + ) + + +def up_sampling_block_config(in_channels=3, out_channels=1): + activation_block_config = CappedGELUConfig(cap_value=10) + transposed_conv_upsample_block_config = ConvBlockConfig( + in_channels=in_channels, + out_channels=out_channels, + activation=activation_block_config, + upsampling=2, + block_type="TransposedConvUpsample", + ) + return transposed_conv_upsample_block_config + + +def output_layer_config(in_channels=3, out_channels=2): + conv_block_config = ConvBlockConfig( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=1, + dilation=1, + n_layers=1, + block_type="BasicConvBlock", + ) + return conv_block_config + + +def recurrent_block_config(in_channels=3): + recurrent_block_config = RecurrentBlockConfig( + in_channels=in_channels, + kernel_size=1, + block_type="ConvGRUBlock", + ) + return recurrent_block_config + + +def decoder_config( + conv_next_block_config, + up_sampling_block_config, + output_layer_config, + recurrent_block_config, + n_channels=[34, 68, 136], +): + decoder_config = UNetDecoderConfig( + conv_block=conv_next_block_config, + up_sampling_block=up_sampling_block_config, + recurrent_block=recurrent_block_config, + output_layer=output_layer_config, + n_channels=n_channels, + dilations=[4, 2, 1], + ) + return decoder_config + + +def _test_data(): + # create dummy data + def generate_test_data(batch_size=8, time_dim=1, channels=7, img_size=16): + device = get_device() + test_data = th.randn(batch_size, 12, time_dim * channels, img_size, img_size) + return test_data.to(device) + + return generate_test_data + + +def constant_data(): + # create dummy data + def generate_constant_data(channels=2, img_size=16): + device = get_device() + constants = th.randn(12, channels, img_size, img_size) + + return constants.to(device) + + return generate_constant_data + + +def insolation_data(): + # create dummy data + def generate_insolation_data(batch_size=8, time_dim=1, img_size=16): + device = get_device() + insolation = th.randn(batch_size, 12, time_dim, img_size, img_size) + + return insolation.to(device) + + return generate_insolation_data + + +@pytest.mark.parametrize( + "shape", + [ + pytest.param((8, 16)), + ], +) +def test_hpx_init(shape): + in_channels = 7 + out_channels = 7 + prognostic_variables = min(in_channels, out_channels) + n_constants = 1 + decoder_input_channels = 1 + input_time_size = 2 + output_time_size = 4 + device = get_device() + + conv_next_block = conv_next_block_config() + down_sampling_block = down_sampling_block_config() + recurrent_block = recurrent_block_config() + 
encoder = encoder_config(conv_next_block, down_sampling_block) + up_sampling_block = up_sampling_block_config() + output_layer = output_layer_config() + decoder = decoder_config( + conv_next_block, up_sampling_block, output_layer, recurrent_block + ) + + hpx_config_data = { + "type": "HEALPixRecUNet", + "config": { + "encoder": dataclasses.asdict(encoder), + "decoder": dataclasses.asdict(decoder), + "prognostic_variables": prognostic_variables, + "n_constants": n_constants, + "decoder_input_channels": decoder_input_channels, + "input_time_size": input_time_size, + "output_time_size": output_time_size, + }, + } + + stepper_config_data = { + "builder": hpx_config_data, + "in_names": ["x"], + "out_names": ["x"], + "normalization": dataclasses.asdict( + NormalizationConfig( + means={"x": float(np.random.randn(1).item())}, + stds={"x": float(np.random.randn(1).item())}, + ) + ), + } + area = th.ones((1, 16, 32)).to(device) + vertical_coordinate = HybridSigmaPressureCoordinate( + ak=th.arange(7), bk=th.arange(7) + ).to(device) + stepper_config = SingleModuleStepperConfig.from_state(stepper_config_data) + stepper = stepper_config.get_stepper( + img_shape=shape, + gridded_operations=LatLonOperations(area), + vertical_coordinate=vertical_coordinate, + timestep=TIMESTEP, + ) + assert type(stepper.module.module) is HEALPixRecUNet + + +@pytest.mark.parametrize( + "in_channels, out_channels, n_constants, decoder_input_channels, input_time_size, \ + output_time_size, couplings, expected_exception, expected_message", + [ + (7, 7, 1, 1, 2, 4, None, None, None), # Valid case + ( + 7, + 7, + 1, + 1, + 2, + 3, + None, + ValueError, + "'output_time_size' must be a multiple of 'input_time_size'", + ), # Bad input and output time dims + ( + 7, + 7, + 0, + 2, + 2, + 3, + ["t2m", "v10m"], + NotImplementedError, + "support for coupled models with no constant field", + ), # Couplings with no constants + ( + 7, + 7, + 2, + 0, + 2, + 3, + ["t2m", "v10m"], + NotImplementedError, + "support for coupled models with no decoder", + ), # Couplings with no decoder input channels + ( + 7, + 7, + 0, + 0, + 2, + 3, + None, + ValueError, + "'output_time_size' must be a multiple of 'input_time_size'", + ), # No constant fields and no decoder + ], +) +def test_HEALPixRecUNet_initialize( + in_channels, + out_channels, + n_constants, + decoder_input_channels, + input_time_size, + output_time_size, + couplings, + expected_exception, + expected_message, +): + prognostic_variables = min(out_channels, in_channels) + conv_next_block = conv_next_block_config() + up_sampling_block = up_sampling_block_config() + output_layer = output_layer_config() + recurrent_block = recurrent_block_config() + encoder = encoder_config(conv_next_block, down_sampling_block_config()) + decoder = decoder_config( + conv_next_block, up_sampling_block, output_layer, recurrent_block + ) + device = get_device() + + if expected_exception: + with pytest.raises(expected_exception, match=expected_message): + model = HEALPixRecUNet( + encoder=encoder, + decoder=decoder, + input_channels=in_channels, + output_channels=out_channels, + prognostic_variables=prognostic_variables, + n_constants=n_constants, + decoder_input_channels=decoder_input_channels, + input_time_size=input_time_size, + output_time_size=output_time_size, + couplings=couplings, + ).to(device) + else: + model = HEALPixRecUNet( + encoder=encoder, + decoder=decoder, + input_channels=in_channels, + output_channels=out_channels, + prognostic_variables=prognostic_variables, + n_constants=n_constants, + 
decoder_input_channels=decoder_input_channels, + input_time_size=input_time_size, + output_time_size=output_time_size, + couplings=couplings, + ).to(device) + assert isinstance(model, HEALPixRecUNet) + + +def test_HEALPixRecUNet_integration_steps(): + in_channels = 2 + out_channels = 2 + prognostic_variables = min(out_channels, in_channels) + n_constants = 1 + decoder_input_channels = 0 + input_time_size = 2 + output_time_size = 4 + device = get_device() + + conv_next_block = conv_next_block_config() + up_sampling_block = up_sampling_block_config() + output_layer = output_layer_config() + recurrent_block = recurrent_block_config() + encoder = encoder_config(conv_next_block, down_sampling_block_config()) + decoder = decoder_config( + conv_next_block, up_sampling_block, output_layer, recurrent_block + ) + + model = HEALPixRecUNet( + encoder=encoder, + decoder=decoder, + input_channels=in_channels, + output_channels=out_channels, + prognostic_variables=prognostic_variables, + n_constants=n_constants, + decoder_input_channels=decoder_input_channels, + input_time_size=input_time_size, + output_time_size=output_time_size, + ).to(device) + + assert model.integration_steps == output_time_size // input_time_size + + +def test_HEALPixRecUNet_reset(very_fast_only: bool): + if very_fast_only: + pytest.skip("Skipping non-fast tests") + # create a smaller version of the dlwp healpix model + in_channels = 3 + out_channels = 3 + prognostic_variables = min(out_channels, in_channels) + n_constants = 2 + decoder_input_channels = 1 + input_time_size = 2 + output_time_size = 4 + size = 16 + device = get_device() + + conv_next_block = conv_next_block_config() + up_sampling_block = up_sampling_block_config() + output_layer = output_layer_config() + recurrent_block = recurrent_block_config() + encoder = encoder_config(conv_next_block, down_sampling_block_config()) + decoder = decoder_config( + conv_next_block, up_sampling_block, output_layer, recurrent_block + ) + + fix_random_seeds(seed=42) + x = _test_data()(time_dim=input_time_size, channels=in_channels, img_size=size) + decoder_inputs = insolation_data()(time_dim=input_time_size, img_size=size) + constants = constant_data()(channels=n_constants, img_size=size) + batch_size = x.shape[0] + constants = constants.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1) + inputs = th.concat( + (x, decoder_inputs, constants), dim=-3 + ) # [x, decoder_inputs, constants] + + model = HEALPixRecUNet( + encoder=encoder, + decoder=decoder, + input_channels=in_channels, + output_channels=out_channels, + prognostic_variables=prognostic_variables, + n_constants=n_constants, + decoder_input_channels=decoder_input_channels, + input_time_size=input_time_size, + output_time_size=output_time_size, + enable_healpixpad=False, + delta_time="6h", + ).to(device) + + out_var = model(inputs) + model.reset() + + assert compare_output(out_var, model(inputs)) + + +# Checks the model can perform a forward class on various input configurations +# [full inputs, no decoder inputs, no constant inputs] +@pytest.mark.parametrize( + "inputs_config, in_channels, decoder_input_channels, \ + out_channels, input_time_size, output_time_size, n_constants, size", + [ + ([0, 1, 2], 3, 1, 3, 2, 4, 2, 16), # full inputs + ([0, 2], 3, 0, 3, 2, 4, 2, 16), # no decoder inputs + ([0, 1], 3, 1, 3, 2, 4, 0, 16), # no constant inputs + ], +) +def test_HEALPixRecUNet_forward( + inputs_config, + in_channels, + decoder_input_channels, + out_channels, + input_time_size, + output_time_size, + n_constants, + size, + very_fast_only: 
bool, +): + if very_fast_only: + pytest.skip("Skipping non-fast tests") + prognostic_variables = min(out_channels, in_channels) + device = get_device() + conv_next_block = conv_next_block_config() + up_sampling_block = up_sampling_block_config() + output_layer = output_layer_config() + recurrent_block = recurrent_block_config() + encoder = encoder_config(conv_next_block, down_sampling_block_config()) + decoder = decoder_config( + conv_next_block, up_sampling_block, output_layer, recurrent_block + ) + + fix_random_seeds(seed=42) + x = _test_data()(time_dim=input_time_size, channels=in_channels, img_size=size) + batch_size = x.shape[0] + + if decoder_input_channels > 0: + decoder_inputs = insolation_data()(time_dim=input_time_size, img_size=size) + else: + decoder_inputs = insolation_data()(time_dim=0, img_size=size) + constants = constant_data()(channels=n_constants, img_size=size) + constants = constants.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1) + + all_inputs = [x, decoder_inputs, constants] + inputs = th.concat(all_inputs, dim=-3) + + model = HEALPixRecUNet( + encoder=encoder, + decoder=decoder, + input_channels=in_channels, + output_channels=out_channels, + prognostic_variables=prognostic_variables, + n_constants=n_constants, + decoder_input_channels=decoder_input_channels, + input_time_size=input_time_size, + output_time_size=output_time_size, + enable_healpixpad=False, + delta_time="6h", + ).to(device) + model(inputs) + + +# pragma mark - encoder + + +def test_UNetEncoder_initialize(): + device = get_device() + channels = 2 + n_channels = (16, 32, 64) + + # Dicts for block configs used by encoder + conv_block_config = ConvBlockConfig( + in_channels=channels, + block_type="ConvNeXtBlock", + ) + down_sampling_block_config = DownsamplingBlockConfig( + pooling=2, block_type="MaxPool" + ) + + encoder = UNetEncoder( + conv_block=conv_block_config, + down_sampling_block=down_sampling_block_config, + n_channels=n_channels, + input_channels=channels, + ).to(device) + assert isinstance(encoder, UNetEncoder) + + # with dilations + encoder = UNetEncoder( + conv_block=conv_block_config, + down_sampling_block=down_sampling_block_config, + n_channels=n_channels, + input_channels=channels, + dilations=[1, 1, 1], + ).to(device) + assert isinstance(encoder, UNetEncoder) + + +def test_UNetEncoder_forward(): + channels = 2 + hw_size = 16 + b_size = 12 + n_channels = (16, 32, 64) + device = get_device() + + # block configs used by encoder + conv_block_config = ConvBlockConfig( + in_channels=channels, + block_type="ConvNeXtBlock", + ) + down_sampling_block_config = DownsamplingBlockConfig( + pooling=2, block_type="MaxPool" + ) + encoder = UNetEncoder( + conv_block=conv_block_config, + down_sampling_block=down_sampling_block_config, + n_channels=n_channels, + input_channels=channels, + ).to(device) + + tensor_size = [b_size, channels, hw_size, hw_size] + invar = th.rand(tensor_size).to(device) + outvar = encoder(invar) + + # doesn't do anything + encoder.reset() + + # outvar is a module list + for idx, out_tensor in enumerate(outvar): + # verify the channels and h dim are correct + assert out_tensor.shape[1] == n_channels[idx] + # default behaviour is to half the h/w size after first + assert out_tensor.shape[2] == tensor_size[2] // (2**idx) + + +def test_UNetEncoder_reset(): + channels = 2 + n_channels = (16, 32, 64) + device = get_device() + + # Dicts for block configs used by encoder + conv_block_config = ConvBlockConfig( + in_channels=channels, + block_type="ConvNeXtBlock", + ) + 
down_sampling_block_config = DownsamplingBlockConfig( + pooling=2, + block_type="MaxPool", + ) + encoder = UNetEncoder( + conv_block=conv_block_config, + down_sampling_block=down_sampling_block_config, + n_channels=n_channels, + input_channels=channels, + ).to(device) + + # doesn't do anything + encoder.reset() + assert isinstance(encoder, UNetEncoder) + + +def test_UNetDecoder_initilization(): + in_channels = 2 + out_channels = 1 + n_channels = (64, 32, 16) + device = get_device() + + # Dicts for block configs used by decoder + conv_block_config = ConvBlockConfig( + in_channels=in_channels, out_channels=out_channels, block_type="ConvNeXtBlock" + ) + up_sampling_block_config = ConvBlockConfig( + in_channels=in_channels, + out_channels=out_channels, + upsampling=2, + block_type="TransposedConvUpsample", + ) + + output_layer_config = ConvBlockConfig( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=1, + dilation=1, + n_layers=1, + block_type="ConvNeXtBlock", + ) + + recurrent_block_config = RecurrentBlockConfig( + in_channels=2, + kernel_size=1, + block_type="ConvGRUBlock", + ) + + decoder = UNetDecoder( + conv_block=conv_block_config, + up_sampling_block=up_sampling_block_config, + output_layer=output_layer_config, + recurrent_block=recurrent_block_config, + n_channels=n_channels, + ).to(device) + + assert isinstance(decoder, UNetDecoder) + + # without the recurrent block and with dilations + decoder = UNetDecoder( + conv_block=conv_block_config, + up_sampling_block=up_sampling_block_config, + output_layer=output_layer_config, + recurrent_block=None, + n_channels=n_channels, + dilations=[1, 1, 1], + ).to(device) + assert isinstance(decoder, UNetDecoder) + + +def test_UNetDecoder_forward(): + in_channels = 2 + out_channels = 1 + hw_size = 32 + b_size = 12 + n_channels = (64, 32, 16) + device = get_device() + + # Dicts for block configs used by decoder + conv_block_config = ConvBlockConfig( + in_channels=in_channels, out_channels=out_channels, block_type="ConvNeXtBlock" + ) + up_sampling_block_config = ConvBlockConfig( + in_channels=in_channels, + out_channels=out_channels, + upsampling=2, + block_type="TransposedConvUpsample", + ) + output_layer_config = ConvBlockConfig( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=1, + dilation=1, + n_layers=1, + block_type="BasicConvBlock", + ) + recurrent_block_config = RecurrentBlockConfig( + in_channels=2, + kernel_size=1, + block_type="ConvGRUBlock", + ) + + decoder = UNetDecoder( + conv_block=conv_block_config, + up_sampling_block=up_sampling_block_config, + output_layer=output_layer_config, + recurrent_block=recurrent_block_config, + n_channels=n_channels, + ).to(device) + + output_2_size = th.Size([b_size, out_channels, hw_size, hw_size]) + + # build the list of tensors for the decoder + invars = [] + # decoder has an algorithm that goes back to front + for idx in range(len(n_channels) - 1, -1, -1): + tensor_size = [b_size, n_channels[idx], hw_size, hw_size] + invars.append(th.rand(tensor_size).to(device)) + hw_size = hw_size // 2 + + outvar = decoder(invars) + assert outvar.shape == output_2_size + + # make sure history is taken into account with ConvGRU + outvar_hist = decoder(invars) + assert not compare_output(outvar, outvar_hist) + + # check with no recurrent + decoder = UNetDecoder( + conv_block=conv_block_config, + up_sampling_block=up_sampling_block_config, + output_layer=output_layer_config, + recurrent_block=None, + n_channels=n_channels, + dilations=[1, 1, 1], + ).to(device) + + outvar = 
decoder(invars) + assert outvar.shape == output_2_size + + +def test_UNetDecoder_reset(): + in_channels = 2 + out_channels = 1 + hw_size = 32 + b_size = 12 + n_channels = (64, 32, 16) + device = get_device() + + # Dicts for block configs used by decoder + conv_block = ConvBlockConfig(in_channels=in_channels, block_type="ConvNeXtBlock") + up_sampling_block = ConvBlockConfig( + in_channels=in_channels, + out_channels=out_channels, + block_type="TransposedConvUpsample", + ) + output_layer = ConvBlockConfig( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=1, + dilation=1, + n_layers=1, + block_type="BasicConvBlock", + ) + + recurrent_block = RecurrentBlockConfig( + in_channels=2, kernel_size=1, block_type="ConvLSTMBlock" + ) + + decoder = UNetDecoder( + conv_block=conv_block, + up_sampling_block=up_sampling_block, + output_layer=output_layer, + recurrent_block=recurrent_block, + n_channels=n_channels, + ).to(device) + + # build the list of tensors for the decoder + invars = [] + # decoder has an algorithm that goes back to front + for idx in range(len(n_channels) - 1, -1, -1): + tensor_size = [b_size, n_channels[idx], hw_size, hw_size] + invars.append(th.rand(tensor_size).to(device)) + hw_size = hw_size // 2 + + outvar = decoder(invars) + + # make sure history is taken into account with ConvGRU + outvar_hist = decoder(invars) + assert not compare_output(outvar, outvar_hist) + + # make sure after reset we get the same result + decoder.reset() + outvar_reset = decoder(invars) + assert compare_output(outvar, outvar_reset) + + # test reset without recurrent block + decoder = UNetDecoder( + conv_block=conv_block, + up_sampling_block=up_sampling_block, + output_layer=output_layer, + recurrent_block=None, + n_channels=n_channels, + ).to(device) + + outvar = decoder(invars) + + # without the recurrent block should be the same + outvar_hist = decoder(invars) + assert compare_output(outvar, outvar_hist) + + # make sure after reset we get the same result + decoder.reset() + outvar_reset = decoder(invars) + assert compare_output(outvar, outvar_reset) + + +def compare_output( + output_1: Union[th.Tensor, Tuple[th.Tensor, ...]], + output_2: Union[th.Tensor, Tuple[th.Tensor, ...]], + rtol: float = 1e-5, + atol: float = 1e-5, +) -> bool: + """Compares model outputs and returns if they are the same + + Args + output_1: First item to compare + output_2: Second item to compare + rtol: Relative tolerance of error allowed, by default 1e-5 + atol: Absolute tolerance of error allowed, by default 1e-5 + + Returns: + If outputs are the same + """ + # Output of tensor + if isinstance(output_1, th.Tensor): + return th.allclose(output_1, output_2, rtol, atol) + # Output of tuple of tensors + elif isinstance(output_1, tuple): + # Loop through tuple of outputs + for i, (out_1, out_2) in enumerate(zip(output_1, output_2)): + # If tensor use allclose + if isinstance(out_1, th.Tensor): + if not th.allclose(out_1, out_2, rtol, atol): + logger.warning(f"Failed comparison between outputs {i}") + logger.warning(f"Max Difference: {th.amax(th.abs(out_1 - out_2))}") + logger.warning(f"Difference: {out_1 - out_2}") + return False + # Otherwise assume primative + else: + if not out_1 == out_2: + return False + elif isinstance(output_1, (list, tuple)) and isinstance(output_2, (list, tuple)): + if len(output_1) != len(output_2): + print( + f"Length mismatch: output_1 {len(output_1)}, output_2 {len(output_2)}" + ) + return False + for a, e in zip(output_1, output_2): + if not compare_output(a, e): + return False + 
return True + elif isinstance(output_1, dict) and isinstance(output_2, dict): + if output_1.keys() != output_2.keys(): + print( + f"Keys mismatch: output_1 keys {output_1.keys()}, ", + f"output_2 keys {output_2.keys()}", + ) + return False + for key in output_1: + if not compare_output(output_1[key], output_2[key]): + return False + return True + # Unsupported output type + else: + logger.error( + "Model returned invalid type for unit test, \ + should be th.Tensor or Tuple[th.Tensor]" + ) + return False + return True diff --git a/fme/fme/ace/registry/test_sfno.py b/fme/fme/ace/registry/test_sfno.py index 046a039..95e7610 100644 --- a/fme/fme/ace/registry/test_sfno.py +++ b/fme/fme/ace/registry/test_sfno.py @@ -5,10 +5,11 @@ import pytest import torch -from fme.core.data_loading.data_typing import SigmaCoordinates +from fme.ace.stepper import SingleModuleStepperConfig +from fme.core.coordinates import HybridSigmaPressureCoordinate from fme.core.device import get_device -from fme.core.normalizer import FromStateNormalizer -from fme.core.stepper import SingleModuleStepperConfig +from fme.core.gridded_ops import LatLonOperations +from fme.core.normalizer import NormalizationConfig TIMESTEP = datetime.timedelta(hours=6) @@ -34,23 +35,21 @@ def test_sfno_init(shape): "in_names": ["x"], "out_names": ["x"], "normalization": dataclasses.asdict( - FromStateNormalizer( - state={ - "means": {"x": float(np.random.randn(1))}, - "stds": {"x": float(np.random.randn(1))}, - } + NormalizationConfig( + means={"x": float(np.random.randn(1).item())}, + stds={"x": float(np.random.randn(1).item())}, ) ), } area = torch.ones((1, 16, 32)).to(get_device()) - sigma_coordinates = SigmaCoordinates(ak=torch.arange(7), bk=torch.arange(7)).to( - get_device() - ) + vertical_coordinate = HybridSigmaPressureCoordinate( + ak=torch.arange(7), bk=torch.arange(7) + ).to(get_device()) stepper_config = SingleModuleStepperConfig.from_state(stepper_config_data) stepper = stepper_config.get_stepper( img_shape=shape, - area=area, - sigma_coordinates=sigma_coordinates, + gridded_operations=LatLonOperations(area), + vertical_coordinate=vertical_coordinate, timestep=TIMESTEP, ) assert len(stepper.module.module.blocks) == num_layers diff --git a/fme/fme/ace/requirements.py b/fme/fme/ace/requirements.py new file mode 100644 index 0000000..4c31070 --- /dev/null +++ b/fme/fme/ace/requirements.py @@ -0,0 +1,16 @@ +import dataclasses +from typing import List + + +@dataclasses.dataclass +class PrognosticStateDataRequirements: + """ + The requirements for the model's prognostic state. + + Parameters: + names: Names of prognostic variables. + n_timesteps: Number of consecutive timesteps that must be stored. 
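+
+    A minimal construction sketch (field values are illustrative only):
+
+        PrognosticStateDataRequirements(names=["x"], n_timesteps=1)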
+ """ + + names: List[str] + n_timesteps: int diff --git a/fme/fme/ace/stepper.py b/fme/fme/ace/stepper.py new file mode 100644 index 0000000..830ae14 --- /dev/null +++ b/fme/fme/ace/stepper.py @@ -0,0 +1,955 @@ +import dataclasses +import datetime +import logging +from copy import copy +from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union + +import dacite +import torch +import xarray as xr +from torch import nn + +from fme.ace.data_loading.batch_data import BatchData, PairedData, PrognosticState +from fme.ace.inference.derived_variables import compute_derived_quantities +from fme.ace.requirements import PrognosticStateDataRequirements +from fme.core.coordinates import HybridSigmaPressureCoordinate +from fme.core.corrector.corrector import CorrectorConfig +from fme.core.dataset.requirements import DataRequirements +from fme.core.dataset.utils import decode_timestep, encode_timestep +from fme.core.device import get_device +from fme.core.distributed import Distributed +from fme.core.generics.inference import PredictFunction +from fme.core.generics.optimization import OptimizationABC +from fme.core.generics.train_stepper import TrainOutputABC, TrainStepperABC +from fme.core.gridded_ops import GriddedOperations, LatLonOperations +from fme.core.loss import WeightedMappingLossConfig +from fme.core.normalizer import NormalizationConfig, StandardNormalizer +from fme.core.ocean import Ocean, OceanConfig +from fme.core.optimization import NullOptimization +from fme.core.packer import Packer +from fme.core.parameter_init import ParameterInitializationConfig +from fme.core.registry import CorrectorSelector, ModuleSelector +from fme.core.timing import GlobalTimer +from fme.core.typing_ import TensorDict, TensorMapping + +DEFAULT_TIMESTEP = datetime.timedelta(hours=6) +DEFAULT_ENCODED_TIMESTEP = encode_timestep(DEFAULT_TIMESTEP) + + +class AtmosphericDeriveFn: + def __init__( + self, + vertical_coordinate: HybridSigmaPressureCoordinate, + timestep: datetime.timedelta, + ): + self.vertical_coordinate = vertical_coordinate.to( + "cpu" + ) # must be on cpu for multiprocessing fork context + self.timestep = timestep + + def __call__(self, data: TensorMapping, forcing_data: TensorMapping) -> TensorDict: + return compute_derived_quantities( + dict(data), + vertical_coordinate=self.vertical_coordinate.to(get_device()), + timestep=self.timestep, + forcing_data=dict(forcing_data), + ) + + +@dataclasses.dataclass +class SingleModuleStepperConfig: + """ + Configuration for a single module stepper. + + Parameters: + builder: The module builder. + in_names: Names of input variables. + out_names: Names of output variables. + normalization: The normalization configuration. + parameter_init: The parameter initialization configuration. + ocean: The ocean configuration. + loss: The loss configuration. + corrector: The corrector configuration. + next_step_forcing_names: Names of forcing variables for the next timestep. + loss_normalization: The normalization configuration for the loss. + residual_normalization: Optional alternative to configure loss normalization. + If provided, it will be used for all *prognostic* variables in loss scaling. 
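+
+    A minimal sketch of the corresponding configuration data (values are illustrative
+    and mirror the structure used in the registry tests; such a dict is consumed via
+    from_state):
+
+        {
+            "builder": {"type": "SphericalFourierNeuralOperatorNet", "config": {...}},
+            "in_names": ["x"],
+            "out_names": ["x"],
+            "normalization": {"means": {"x": 0.0}, "stds": {"x": 1.0}},
+        }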
+ """ + + builder: ModuleSelector + in_names: List[str] + out_names: List[str] + normalization: NormalizationConfig + parameter_init: ParameterInitializationConfig = dataclasses.field( + default_factory=lambda: ParameterInitializationConfig() + ) + ocean: Optional[OceanConfig] = None + loss: WeightedMappingLossConfig = dataclasses.field( + default_factory=lambda: WeightedMappingLossConfig() + ) + corrector: Union[CorrectorConfig, CorrectorSelector] = dataclasses.field( + default_factory=lambda: CorrectorConfig() + ) + next_step_forcing_names: List[str] = dataclasses.field(default_factory=list) + loss_normalization: Optional[NormalizationConfig] = None + residual_normalization: Optional[NormalizationConfig] = None + + def __post_init__(self): + for name in self.next_step_forcing_names: + if name not in self.in_names: + raise ValueError( + f"next_step_forcing_name '{name}' not in in_names: {self.in_names}" + ) + if name in self.out_names: + raise ValueError( + f"next_step_forcing_name is an output variable: '{name}'" + ) + if ( + self.residual_normalization is not None + and self.loss_normalization is not None + ): + raise ValueError( + "Only one of residual_normalization, loss_normalization can " + "be provided." + "If residual_normalization is provided, it will be used for all " + "*prognostic* variables in loss scalng. " + "If loss_normalization is provided, it will be used for all variables " + "in loss scaling." + ) + + @property + def n_ic_timesteps(self) -> int: + return 1 + + def get_evaluation_window_data_requirements( + self, n_forward_steps: int + ) -> DataRequirements: + return DataRequirements( + names=self.all_names, + n_timesteps=self._window_steps_required(n_forward_steps), + ) + + def get_prognostic_state_data_requirements(self) -> PrognosticStateDataRequirements: + return PrognosticStateDataRequirements( + names=self.prognostic_names, + n_timesteps=self.n_ic_timesteps, + ) + + def get_forcing_window_data_requirements( + self, n_forward_steps: int + ) -> DataRequirements: + if self.ocean is None: + names = self.forcing_names + else: + names = list(set(self.forcing_names).union(self.ocean.forcing_names)) + + return DataRequirements( + names=names, + n_timesteps=self._window_steps_required(n_forward_steps), + ) + + def _window_steps_required(self, n_forward_steps: int) -> int: + return n_forward_steps + self.n_ic_timesteps + + def get_state(self): + return dataclasses.asdict(self) + + def get_base_weights(self) -> Optional[List[Mapping[str, Any]]]: + """ + If the model is being initialized from another model's weights for fine-tuning, + returns those weights. Otherwise, returns None. + + The list mirrors the order of `modules` in the `SingleModuleStepper` class. 
+ """ + base_weights = self.parameter_init.get_base_weights() + if base_weights is not None: + return [base_weights] + else: + return None + + def get_stepper( + self, + img_shape: Tuple[int, int], + gridded_operations: GriddedOperations, + vertical_coordinate: HybridSigmaPressureCoordinate, + timestep: datetime.timedelta, + ): + logging.info("Initializing stepper from provided config") + derive_func = AtmosphericDeriveFn(vertical_coordinate, timestep) + return SingleModuleStepper( + config=self, + img_shape=img_shape, + gridded_operations=gridded_operations, + vertical_coordinate=vertical_coordinate, + timestep=timestep, + derive_func=derive_func, + ) + + @classmethod + def from_state(cls, state) -> "SingleModuleStepperConfig": + state = cls.remove_deprecated_keys(state) + return dacite.from_dict( + data_class=cls, data=state, config=dacite.Config(strict=True) + ) + + @property + def all_names(self): + """Names of all variables required, including auxiliary ones.""" + extra_names = [] + if self.ocean is not None: + extra_names.extend(self.ocean.forcing_names) + all_names = list(set(self.in_names).union(self.out_names).union(extra_names)) + return all_names + + @property + def normalize_names(self): + """Names of variables which require normalization. I.e. inputs/outputs.""" + return list(set(self.in_names).union(self.out_names)) + + @property + def forcing_names(self) -> List[str]: + """Names of variables which are inputs only.""" + return list(set(self.in_names) - set(self.out_names)) + + @property + def prognostic_names(self) -> List[str]: + """Names of variables which both inputs and outputs.""" + return list(set(self.out_names).intersection(self.in_names)) + + @property + def diagnostic_names(self) -> List[str]: + """Names of variables which both inputs and outputs.""" + return list(set(self.out_names).difference(self.in_names)) + + @classmethod + def remove_deprecated_keys(cls, state: Dict[str, Any]) -> Dict[str, Any]: + _unsupported_key_defaults = { + "conserve_dry_air": False, + "optimization": None, + "conservation_loss": {"dry_air_penalty": None}, + } + state_copy = state.copy() + for key, default in _unsupported_key_defaults.items(): + if key in state_copy: + if state_copy[key] == default or state_copy[key] is None: + del state_copy[key] + else: + raise ValueError( + f"The stepper config option {key} is deprecated and the setting" + f" provided, {state_copy[key]}, is no longer implemented. The " + "SingleModuleStepper being loaded from state cannot be run by " + "this version of the code." + ) + for normalization_key in [ + "normalization", + "loss_normalization", + "residual_normalization", + ]: + if state_copy.get(normalization_key) is not None: + if "exclude_names" in state_copy[normalization_key]: + if state_copy[normalization_key]["exclude_names"] is not None: + raise ValueError( + "The exclude_names option in normalization config is no " + "longer supported, but excluded names were found in " + f"{normalization_key}." 
+ ) + else: + del state_copy[normalization_key]["exclude_names"] + if "prescriber" in state_copy: + # want to maintain backwards compatibility for this particular feature + if state_copy["prescriber"] is not None: + if state_copy.get("ocean") is not None: + raise ValueError("Cannot specify both prescriber and ocean.") + state_copy["ocean"] = { + "surface_temperature_name": state_copy["prescriber"][ + "prescribed_name" + ], + "ocean_fraction_name": state_copy["prescriber"]["mask_name"], + "interpolate": state_copy["prescriber"]["interpolate"], + } + del state_copy["prescriber"] + return state_copy + + +@dataclasses.dataclass +class ExistingStepperConfig: + """ + Configuration for an existing stepper. This is only designed to point to + a serialized stepper checkpoint for loading, e.g., in the case of training + resumption. + + Parameters: + checkpoint_path: The path to the serialized checkpoint. + """ + + checkpoint_path: str + + def __post_init__(self): + self._stepper_config = SingleModuleStepperConfig.from_state( + self._load_checkpoint()["stepper"]["config"] + ) + + def _load_checkpoint(self) -> Mapping[str, Any]: + return torch.load(self.checkpoint_path, map_location=get_device()) + + def get_evaluation_window_data_requirements( + self, n_forward_steps: int + ) -> DataRequirements: + return self._stepper_config.get_evaluation_window_data_requirements( + n_forward_steps + ) + + def get_prognostic_state_data_requirements(self) -> PrognosticStateDataRequirements: + return self._stepper_config.get_prognostic_state_data_requirements() + + def get_forcing_window_data_requirements( + self, n_forward_steps: int + ) -> DataRequirements: + return self._stepper_config.get_forcing_window_data_requirements( + n_forward_steps + ) + + def get_base_weights(self) -> Optional[List[Mapping[str, Any]]]: + return self._stepper_config.get_base_weights() + + def get_stepper(self, img_shape, gridded_operations, vertical_coordinate, timestep): + del img_shape # unused + logging.info(f"Initializing stepper from {self.checkpoint_path}") + return SingleModuleStepper.from_state(self._load_checkpoint()["stepper"]) + + +def _combine_normalizers( + residual_normalizer: StandardNormalizer, + model_normalizer: StandardNormalizer, +) -> StandardNormalizer: + # Combine residual and model normalizers by overwriting the model normalizer + # values that are present in residual normalizer. The residual normalizer + # is assumed to have a subset of prognostic keys only. 
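+    # Illustrative example (hypothetical values): with model means
+    # {"a": 0.0, "b": 0.0} and residual means {"a": 1.0}, the combined
+    # normalizer has means {"a": 1.0, "b": 0.0}; stds combine the same way.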
+ means, stds = copy(model_normalizer.means), copy(model_normalizer.stds) + means.update(residual_normalizer.means) + stds.update(residual_normalizer.stds) + return StandardNormalizer(means=means, stds=stds) + + +def _prepend_timesteps( + data: TensorMapping, timesteps: TensorMapping, time_dim: int = 1 +) -> TensorDict: + return {k: torch.cat([timesteps[k], v], dim=time_dim) for k, v in data.items()} + + +@dataclasses.dataclass +class TrainOutput(TrainOutputABC): + metrics: TensorDict + gen_data: TensorDict + target_data: TensorDict + time: xr.DataArray + normalize: Callable[[TensorDict], TensorDict] + derive_func: Callable[[TensorMapping, TensorMapping], TensorDict] = ( + lambda x, _: dict(x) + ) + + def remove_initial_condition(self, n_ic_timesteps: int) -> "TrainOutput": + return TrainOutput( + metrics=self.metrics, + gen_data={k: v[:, n_ic_timesteps:] for k, v in self.gen_data.items()}, + target_data={k: v[:, n_ic_timesteps:] for k, v in self.target_data.items()}, + time=self.time[:, n_ic_timesteps:], + normalize=self.normalize, + derive_func=self.derive_func, + ) + + def copy(self) -> "TrainOutput": + """Creates new dictionaries for the data but with the same tensors.""" + return TrainOutput( + metrics=self.metrics, + gen_data={k: v for k, v in self.gen_data.items()}, + target_data={k: v for k, v in self.target_data.items()}, + time=self.time, + normalize=self.normalize, + derive_func=self.derive_func, + ) + + def prepend_initial_condition( + self, + initial_condition: PrognosticState, + ) -> "TrainOutput": + """ + Prepends an initial condition to the existing stepped data. + Assumes data are on the same device. + For data windows > 0, the target IC is different from the generated IC + and may be provided for correct calculation of tendencies. + + Args: + initial_condition: Initial condition data. + """ + batch_data = initial_condition.as_batch_data() + return TrainOutput( + metrics=self.metrics, + gen_data=_prepend_timesteps(self.gen_data, batch_data.data), + target_data=_prepend_timesteps( + self.target_data, + batch_data.data, + ), + time=xr.concat([batch_data.time, self.time], dim="time"), + normalize=self.normalize, + derive_func=self.derive_func, + ) + + def compute_derived_variables( + self, + ) -> "TrainOutput": + gen_data = self.derive_func(self.gen_data, self.target_data) + target_data = self.derive_func(self.target_data, self.target_data) + return TrainOutput( + metrics=self.metrics, + gen_data=gen_data, + target_data=target_data, + time=self.time, + normalize=self.normalize, + derive_func=self.derive_func, + ) + + def get_metrics(self) -> TensorDict: + return self.metrics + + +class SingleModuleStepper( + TrainStepperABC[ + PrognosticState, + BatchData, + BatchData, + PairedData, + TrainOutput, + ], +): + """ + Stepper class for a single pytorch module. + """ + + TIME_DIM = 1 + CHANNEL_DIM = -3 + + def __init__( + self, + config: SingleModuleStepperConfig, + img_shape: Tuple[int, int], + gridded_operations: GriddedOperations, + vertical_coordinate: HybridSigmaPressureCoordinate, + derive_func: Callable[[TensorMapping, TensorMapping], TensorDict], + timestep: datetime.timedelta, + init_weights: bool = True, + ): + """ + Args: + config: The configuration. + img_shape: Shape of domain as (n_lat, n_lon). + gridded_operations: The gridded operations, e.g. for area weighting. + vertical_coordinate: The vertical coordinate. + derive_func: Function to compute derived variables. + timestep: Timestep of the model. + init_weights: Whether to initialize the weights. 
Should pass False if + the weights are about to be overwritten by a checkpoint. + """ + self._gridded_operations = gridded_operations # stored for serializing + n_in_channels = len(config.in_names) + n_out_channels = len(config.out_names) + self.in_packer = Packer(config.in_names) + self.out_packer = Packer(config.out_names) + self.normalizer = config.normalization.build(config.normalize_names) + if config.ocean is not None: + self.ocean: Optional[Ocean] = config.ocean.build( + config.in_names, config.out_names, timestep + ) + else: + self.ocean = None + self.module = config.builder.build( + n_in_channels=n_in_channels, + n_out_channels=n_out_channels, + img_shape=img_shape, + ) + module, self._l2_sp_tuning_regularizer = config.parameter_init.apply( + self.module, init_weights=init_weights + ) + self.module = module.to(get_device()) + self.derive_func = derive_func + self._img_shape = img_shape + self._config = config + self._no_optimization = NullOptimization() + + dist = Distributed.get_instance() + self._is_distributed = dist.is_distributed() + self.module = dist.wrap_module(self.module) + + self._vertical_coordinates = vertical_coordinate.to(get_device()) + self._timestep = timestep + + self.loss_obj = config.loss.build( + gridded_operations.area_weighted_mean, config.out_names, self.CHANNEL_DIM + ) + + self._corrector = config.corrector.build( + gridded_operations=gridded_operations, + vertical_coordinate=self.vertical_coordinate, + timestep=timestep, + ) + if config.loss_normalization is not None: + self.loss_normalizer = config.loss_normalization.build( + names=config.normalize_names + ) + elif config.residual_normalization is not None: + # Use residual norm for prognostic variables and input/output + # normalizer for diagnostic variables in loss + self.loss_normalizer = _combine_normalizers( + residual_normalizer=config.residual_normalization.build( + config.prognostic_names + ), + model_normalizer=self.normalizer, + ) + else: + self.loss_normalizer = self.normalizer + self.in_names = config.in_names + self.out_names = config.out_names + + _1: PredictFunction[ # for type checking + PrognosticState, + BatchData, + BatchData, + ] = self.predict + + _2: PredictFunction[ # for type checking + PrognosticState, + BatchData, + PairedData, + ] = self.predict_paired + + @property + def vertical_coordinate(self) -> HybridSigmaPressureCoordinate: + return self._vertical_coordinates + + @property + def timestep(self) -> datetime.timedelta: + return self._timestep + + @property + def surface_temperature_name(self) -> Optional[str]: + if self._config.ocean is not None: + return self._config.ocean.surface_temperature_name + return None + + @property + def ocean_fraction_name(self) -> Optional[str]: + if self._config.ocean is not None: + return self._config.ocean.ocean_fraction_name + return None + + @property + def effective_loss_scaling(self) -> TensorDict: + """ + Effective loss scalings used to normalize outputs before computing loss. + y_loss_normalized_i = (y_i - y_mean_i) / loss_scaling_i + where loss_scaling_i = loss_normalizer_std_i / weight_i. + """ + custom_weights = self._config.loss.weights + loss_normalizer_stds = self.loss_normalizer.stds + return { + k: loss_normalizer_stds[k] / custom_weights.get(k, 1.0) + for k in self._config.out_names + } + + def replace_ocean(self, ocean: Ocean): + """ + Replace the ocean model with a new one. + + Args: + ocean: The new ocean model. 
+ """ + self.ocean = ocean + + @property + def forcing_names(self) -> List[str]: + """Names of variables which are inputs only.""" + return self._config.forcing_names + + @property + def prognostic_names(self) -> List[str]: + return sorted( + list(set(self.out_packer.names).intersection(self.in_packer.names)) + ) + + @property + def diagnostic_names(self) -> List[str]: + return sorted(list(set(self.out_packer.names).difference(self.in_packer.names))) + + @property + def n_ic_timesteps(self) -> int: + return 1 + + @property + def modules(self) -> nn.ModuleList: + """ + Returns: + A list of modules being trained. + """ + return nn.ModuleList([self.module]) + + def step( + self, + input: TensorMapping, + next_step_forcing_data: TensorMapping, + ) -> TensorDict: + """ + Step the model forward one timestep given input data. + + Args: + input: Mapping from variable name to tensor of shape + [n_batch, n_lat, n_lon]. This data is used as input for `self.module` + and is assumed to contain all input variables and be denormalized. + next_step_forcing_data: Mapping from variable name to tensor of shape + [n_batch, n_lat, n_lon]. This must contain the necessary forcing + data at the output timestep for the ocean model and corrector. + + Returns: + The denormalized output data at the next time step. + """ + input_norm = self.normalizer.normalize(input) + input_tensor = self.in_packer.pack(input_norm, axis=self.CHANNEL_DIM) + output_tensor = self.module(input_tensor) + output_norm = self.out_packer.unpack(output_tensor, axis=self.CHANNEL_DIM) + output = self.normalizer.denormalize(output_norm) + if self._corrector is not None: + output = self._corrector(input, output, next_step_forcing_data) + if self.ocean is not None: + output = self.ocean(input, output, next_step_forcing_data) + return output + + def _predict( + self, + initial_condition: TensorMapping, + forcing_data: TensorMapping, + n_forward_steps: int, + ) -> TensorDict: + """ + Predict multiple steps forward given initial condition and forcing data. + + Uses low-level inputs and does not compute derived variables, to separate + concerns from the public `predict` method. + + Args: + initial_condition: The initial condition, containing tensors of shape + [n_batch, self.n_ic_timesteps, ]. + forcing_data: The forcing data, containing tensors of shape + [n_batch, n_forward_steps + self.n_ic_timesteps, ]. + n_forward_steps: The number of forward steps to predict, corresponding + to the data shapes of forcing_data. + + Returns: + The output data at each timestep. 
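+
+        Note:
+            Forcing variables listed in `next_step_forcing_names` are passed to
+            the module at the output timestep (step + 1); all other forcing
+            inputs are taken at the input timestep (step).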
+ """ + state = { + k: initial_condition[k].squeeze(self.TIME_DIM) for k in initial_condition + } + ml_forcing_names = self._config.forcing_names + output_list = [] + for step in range(n_forward_steps): + ml_input_forcing = { + k: ( + forcing_data[k][:, step] + if k not in self._config.next_step_forcing_names + else forcing_data[k][:, step + 1] + ) + for k in ml_forcing_names + } + next_step_forcing_data = { + k: forcing_data[k][:, step + 1] for k in self._forcing_names() + } + input_data = {**state, **ml_input_forcing} + state = self.step(input_data, next_step_forcing_data) + output_list.append(state) + output_timeseries = {} + for name in state: + output_timeseries[name] = torch.stack( + [x[name] for x in output_list], dim=self.TIME_DIM + ) + return output_timeseries + + def predict( + self, + initial_condition: PrognosticState, + forcing: BatchData, + compute_derived_variables: bool = False, + ) -> Tuple[BatchData, PrognosticState]: + """ + Predict multiple steps forward given initial condition and reference data. + + Args: + initial_condition: Prognostic state data with tensors of shape + [n_batch, self.n_ic_timesteps, ]. This data is assumed + to contain all prognostic variables and be denormalized. + forcing: Contains tensors of shape + [n_batch, self.n_ic_timesteps + n_forward_steps, n_lat, n_lon]. This + contains the forcing and ocean data for the initial condition and all + subsequent timesteps. + compute_derived_variables: Whether to compute derived variables for the + prediction. + + Returns: + A batch data containing the prediction and the prediction's final state + which can be used as a new initial condition. + """ + timer = GlobalTimer.get_instance() + with timer.context("forward_prediction"): + forcing_data = forcing.subset_names(self._forcing_names()) + initial_condition_state = initial_condition.as_batch_data() + if initial_condition_state.time.shape[1] != self.n_ic_timesteps: + raise ValueError( + f"Initial condition must have {self.n_ic_timesteps} timesteps, got " + f"{initial_condition_state.time.shape[1]}." + ) + n_forward_steps = forcing_data.time.shape[1] - self.n_ic_timesteps + output_timeseries = self._predict( + initial_condition_state.data, forcing_data.data, n_forward_steps + ) + data = BatchData.new_on_device( + output_timeseries, + forcing_data.time[:, self.n_ic_timesteps :], + horizontal_dims=forcing_data.horizontal_dims, + ) + if compute_derived_variables: + with timer.context("compute_derived_variables"): + data = ( + data.prepend(initial_condition) + .compute_derived_variables( + derive_func=self.derive_func, + forcing_data=forcing_data, + ) + .remove_initial_condition(self.n_ic_timesteps) + ) + return data, data.get_end(self.prognostic_names, self.n_ic_timesteps) + + def predict_paired( + self, + initial_condition: PrognosticState, + forcing: BatchData, + compute_derived_variables: bool = False, + ) -> Tuple[PairedData, PrognosticState]: + """ + Predict multiple steps forward given initial condition and reference data. + + Args: + initial_condition: Prognostic state data with tensors of shape + [n_batch, self.n_ic_timesteps, ]. This data is assumed + to contain all prognostic variables and be denormalized. + forcing: Contains tensors of shape + [n_batch, self.n_ic_timesteps + n_forward_steps, n_lat, n_lon]. This + contains the forcing and ocean data for the initial condition and all + subsequent timesteps. + compute_derived_variables: Whether to compute derived variables for the + prediction. 
+ + Returns: + A paired data containing the prediction paired with all forcing data at the + same timesteps and the prediction's final state which can be used as a + new initial condition. + """ + prediction, new_initial_condition = self.predict( + initial_condition, forcing, compute_derived_variables + ) + return ( + PairedData.from_batch_data( + prediction=prediction, + target=self.get_forward_data( + forcing, compute_derived_variables=compute_derived_variables + ), + ), + new_initial_condition, + ) + + def get_forward_data( + self, data: BatchData, compute_derived_variables: bool = False + ) -> BatchData: + if compute_derived_variables: + timer = GlobalTimer.get_instance() + with timer.context("compute_derived_variables"): + data = data.compute_derived_variables( + derive_func=self.derive_func, + forcing_data=data, + ) + return data.remove_initial_condition(self.n_ic_timesteps) + + def _forcing_names(self) -> List[str]: + if self.ocean is None: + return self._config.forcing_names + return list(set(self._config.forcing_names).union(self.ocean.forcing_names)) + + def train_on_batch( + self, + data: BatchData, + optimization: OptimizationABC, + compute_derived_variables: bool = False, + ) -> TrainOutput: + """ + Step the model forward multiple steps on a batch of data. + + Args: + data: The batch data where each tensor in data.data has shape + [n_sample, n_forward_steps + self.n_ic_timesteps, ]. + optimization: The optimization class to use for updating the module. + Use `NullOptimization` to disable training. + compute_derived_variables: Whether to compute derived variables for the + prediction and target data. + + Returns: + The loss metrics, the generated data, the normalized generated data, + and the normalized batch data. + """ + time_dim = self.TIME_DIM + + loss = torch.tensor(0.0, device=get_device()) + metrics: Dict[str, float] = {} + input_data = data.get_start(self.prognostic_names, self.n_ic_timesteps) + + optimization.set_mode(self.module) + with optimization.autocast(): + # output from self.predict does not include initial condition + output, _ = self.predict_paired( + input_data, + forcing=data, + ) + gen_data = output.prediction + target_data = output.target + n_forward_steps = output.time.shape[1] + + # compute loss for each timestep + for step in range(n_forward_steps): + # Note: here we examine the loss for a single timestep, + # not a single model call (which may contain multiple timesteps). + gen_step = {k: v.select(time_dim, step) for k, v in gen_data.items()} + target_step = { + k: v.select(time_dim, step) for k, v in target_data.items() + } + gen_norm_step = self.loss_normalizer.normalize(gen_step) + target_norm_step = self.loss_normalizer.normalize(target_step) + + step_loss = self.loss_obj(gen_norm_step, target_norm_step) + loss += step_loss + metrics[f"loss_step_{step}"] = step_loss.detach() + + loss += self._l2_sp_tuning_regularizer() + + metrics["loss"] = loss.detach() + optimization.step_weights(loss) + + stepped = TrainOutput( + metrics=metrics, + gen_data=dict(gen_data), + target_data=dict(target_data), + time=output.time, + normalize=self.normalizer.normalize, + derive_func=self.derive_func, + ) + ic = data.get_start( + set(data.data.keys()), self.n_ic_timesteps + ) # full data and not just prognostic get prepended + stepped = stepped.prepend_initial_condition(ic) + if compute_derived_variables: + stepped = stepped.compute_derived_variables() + return stepped + + def get_state(self): + """ + Returns: + The state of the stepper. 
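+
+        Example:
+            A round-trip that mirrors the pattern used in the tests::
+
+                new_stepper = SingleModuleStepper.from_state(stepper.get_state())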
+ """ + return { + "module": self.module.state_dict(), + "normalizer": self.normalizer.get_state(), + "img_shape": self._img_shape, + "config": self._config.get_state(), + "gridded_operations": self._gridded_operations.to_state(), + "vertical_coordinate": self.vertical_coordinate.as_dict(), + "encoded_timestep": encode_timestep(self.timestep), + "loss_normalizer": self.loss_normalizer.get_state(), + } + + def load_state(self, state: Dict[str, Any]) -> None: + """ + Load the state of the stepper. + + Args: + state: The state to load. + """ + if "module" in state: + module = state["module"] + if "module.device_buffer" in module: + # for backwards compatibility with old checkpoints + del module["module.device_buffer"] + self.module.load_state_dict(module) + + @classmethod + def from_state(cls, state) -> "SingleModuleStepper": + """ + Load the state of the stepper. + + Args: + state: The state to load. + + Returns: + The stepper. + """ + config = {**state["config"]} # make a copy to avoid mutating input + config["normalization"] = state["normalizer"] + + # for backwards compatibility with previous steppers created w/o + # loss_normalization or residual_normalization + loss_normalizer_state = state.get("loss_normalizer", state["normalizer"]) + config["loss_normalization"] = loss_normalizer_state + + # Overwrite the residual_normalization key if it exists, since the combined + # loss scalings are saved in initial training as the loss_normalization + config["residual_normalization"] = None + + if "area" in state: + # backwards-compatibility, these older checkpoints are always lat-lon + gridded_operations: GriddedOperations = LatLonOperations(state["area"]) + else: + gridded_operations = GriddedOperations.from_state( + state["gridded_operations"] + ) + + if "sigma_coordinates" in state: + # for backwards compatibility with old checkpoints + state["vertical_coordinate"] = state["sigma_coordinates"] + + vertical_coordinate = dacite.from_dict( + data_class=HybridSigmaPressureCoordinate, + data=state["vertical_coordinate"], + config=dacite.Config(strict=True), + ) + # for backwards compatibility with original ACE checkpoint which + # serialized vertical coordinates as float64 + if vertical_coordinate.ak.dtype == torch.float64: + vertical_coordinate.ak = vertical_coordinate.ak.to(dtype=torch.float32) + if vertical_coordinate.bk.dtype == torch.float64: + vertical_coordinate.bk = vertical_coordinate.bk.to(dtype=torch.float32) + encoded_timestep = state.get("encoded_timestep", DEFAULT_ENCODED_TIMESTEP) + timestep = decode_timestep(encoded_timestep) + if "img_shape" in state: + img_shape = state["img_shape"] + else: + # this is for backwards compatibility with old checkpoints + for v in state["data_shapes"].values(): + img_shape = v[-2:] + break + derive_func = AtmosphericDeriveFn(vertical_coordinate, timestep) + stepper = cls( + config=SingleModuleStepperConfig.from_state(config), + img_shape=img_shape, + gridded_operations=gridded_operations, + vertical_coordinate=vertical_coordinate, + timestep=timestep, + derive_func=derive_func, + # don't need to initialize weights, we're about to load_state + init_weights=False, + ) + stepper.load_state(state) + return stepper diff --git a/fme/fme/core/test_stepper.py b/fme/fme/ace/test_stepper.py similarity index 56% rename from fme/fme/core/test_stepper.py rename to fme/fme/ace/test_stepper.py index b45a5ed..2ed9f6f 100644 --- a/fme/fme/core/test_stepper.py +++ b/fme/fme/ace/test_stepper.py @@ -3,46 +3,57 @@ from typing import Iterable, List, Literal, Optional, 
Tuple, Union from unittest.mock import MagicMock +import cftime import numpy as np import pytest import torch +import xarray as xr import fme -from fme.ace.inference.derived_variables import compute_stepped_derived_quantities +from fme.ace.aggregator import OneStepAggregator +from fme.ace.aggregator.plotting import plot_paneled_data +from fme.ace.data_loading.batch_data import BatchData, PrognosticState +from fme.ace.stepper import ( + CorrectorConfig, + SingleModuleStepper, + SingleModuleStepperConfig, + TrainOutput, + _combine_normalizers, +) from fme.core import ClimateData, metrics -from fme.core.data_loading.data_typing import SigmaCoordinates +from fme.core.coordinates import HybridSigmaPressureCoordinate from fme.core.device import get_device +from fme.core.gridded_ops import LatLonOperations from fme.core.loss import WeightedMappingLossConfig from fme.core.normalizer import NormalizationConfig, StandardNormalizer from fme.core.ocean import OceanConfig, SlabOceanConfig from fme.core.optimization import NullOptimization, Optimization, OptimizationConfig -from fme.core.registry import ModuleSelector -from fme.core.stepper import ( - CorrectorConfig, - SingleModuleStepper, - SingleModuleStepperConfig, - SteppedData, - _combine_normalizers, -) +from fme.core.registry.module import ModuleSelector from fme.core.typing_ import TensorDict -SphericalData = namedtuple("SphericalData", ["data", "area_weights", "sigma_coords"]) +SphericalData = namedtuple("SphericalData", ["data", "area_weights", "vertical_coord"]) TIMESTEP = datetime.timedelta(hours=6) +DEVICE = fme.get_device() def get_data(names: Iterable[str], n_samples, n_time) -> SphericalData: - data = {} + data_dict = {} n_lat, n_lon, nz = 5, 5, 7 lats = torch.linspace(-89.5, 89.5, n_lat) # arbitary choice for name in names: - data[name] = torch.rand( - n_samples, n_time, n_lat, n_lon, device=fme.get_device() - ) - area_weights = fme.spherical_area_weights(lats, n_lon).to(fme.get_device()) + data_dict[name] = torch.rand(n_samples, n_time, n_lat, n_lon, device=DEVICE) + area_weights = fme.spherical_area_weights(lats, n_lon).to(DEVICE) ak, bk = torch.arange(nz), torch.arange(nz) - sigma_coords = SigmaCoordinates(ak, bk) - return SphericalData(data, area_weights, sigma_coords) + vertical_coord = HybridSigmaPressureCoordinate(ak, bk) + data = BatchData.new_on_device( + data=data_dict, + time=xr.DataArray( + np.zeros((n_samples, n_time)), + dims=["sample", "time"], + ), + ) + return SphericalData(data, area_weights, vertical_coord) def get_scalar_data(names, value): @@ -92,42 +103,53 @@ def test_stepper_config_all_names_property( assert set(config.all_names) == set(expected_all_names) -def test_run_on_batch_normalizer_changes_only_norm_data(): +def test_train_on_batch_normalizer_changes_only_norm_data(): torch.manual_seed(0) data = get_data(["a", "b"], n_samples=5, n_time=2).data - area = torch.ones((5, 5), device=fme.get_device()) - sigma_coordinates = SigmaCoordinates(ak=torch.arange(7), bk=torch.arange(7)) + area = torch.ones((5, 5), device=DEVICE) + gridded_operations = LatLonOperations(area) + vertical_coordinate = HybridSigmaPressureCoordinate( + ak=torch.arange(7), bk=torch.arange(7) + ) + normalization_config = NormalizationConfig( + means=get_scalar_data(["a", "b"], 0.0), + stds=get_scalar_data(["a", "b"], 1.0), + ) config = SingleModuleStepperConfig( builder=ModuleSelector(type="prebuilt", config={"module": torch.nn.Identity()}), in_names=["a", "b"], out_names=["a", "b"], - normalization=NormalizationConfig( - 
means=get_scalar_data(["a", "b"], 0.0), - stds=get_scalar_data(["a", "b"], 1.0), - ), + normalization=normalization_config, loss=WeightedMappingLossConfig(type="MSE"), ) - stepper = config.get_stepper((5, 5), area, sigma_coordinates, TIMESTEP) - stepped = stepper.run_on_batch(data=data, optimization=MagicMock()) + stepper = config.get_stepper( + (5, 5), gridded_operations, vertical_coordinate, TIMESTEP + ) + stepped = stepper.train_on_batch(data=data, optimization=MagicMock()) assert torch.allclose( - stepped.gen_data["a"], stepped.gen_data_norm["a"] + stepped.gen_data["a"], stepped.normalize(stepped.gen_data)["a"] ) # as std=1, mean=0, no change - config.normalization.stds = get_scalar_data(["a", "b"], 2.0) + normalization_config.stds = get_scalar_data(["a", "b"], 2.0) + config.normalization = normalization_config config.loss_normalization = NormalizationConfig( means=get_scalar_data(["a", "b"], 0.0), stds=get_scalar_data(["a", "b"], 3.0), ) - stepper = config.get_stepper((5, 5), area, sigma_coordinates, TIMESTEP) - stepped_double_std = stepper.run_on_batch(data=data, optimization=MagicMock()) + stepper = config.get_stepper( + (5, 5), gridded_operations, vertical_coordinate, TIMESTEP + ) + stepped_double_std = stepper.train_on_batch(data=data, optimization=MagicMock()) assert torch.allclose( stepped.gen_data["a"], stepped_double_std.gen_data["a"], rtol=1e-4 ) assert torch.allclose( - stepped.gen_data["a"], 2.0 * stepped_double_std.gen_data_norm["a"], rtol=1e-4 + stepped.gen_data["a"], + 2.0 * stepped_double_std.normalize(stepped_double_std.gen_data)["a"], + rtol=1e-4, ) assert torch.allclose( stepped.target_data["a"], - 2.0 * stepped_double_std.target_data_norm["a"], + 2.0 * stepped_double_std.normalize(stepped_double_std.target_data)["a"], rtol=1e-4, ) assert torch.allclose( @@ -135,7 +157,7 @@ def test_run_on_batch_normalizer_changes_only_norm_data(): ) # mse scales with std**2 -def test_run_on_batch_addition_series(): +def test_train_on_batch_addition_series(): torch.manual_seed(0) class AddOne(torch.nn.Module): @@ -143,9 +165,12 @@ def forward(self, x): return x + 1 n_steps = 4 - data_with_ic = get_data(["a", "b"], n_samples=5, n_time=n_steps + 1).data - area = torch.ones((5, 5), device=fme.get_device()) - sigma_coordinates = SigmaCoordinates(ak=torch.arange(7), bk=torch.arange(7)) + data_with_ic: BatchData = get_data(["a", "b"], n_samples=5, n_time=n_steps + 1).data + area = torch.ones((5, 5), device=DEVICE) + gridded_operations = LatLonOperations(area) + vertical_coordinate = HybridSigmaPressureCoordinate( + ak=torch.arange(7), bk=torch.arange(7) + ) config = SingleModuleStepperConfig( builder=ModuleSelector(type="prebuilt", config={"module": AddOne()}), in_names=["a", "b"], @@ -156,20 +181,21 @@ def forward(self, x): ), loss=WeightedMappingLossConfig(type="MSE"), ) - stepper = config.get_stepper((5, 5), area, sigma_coordinates, TIMESTEP) - stepped = stepper.run_on_batch( - data=data_with_ic, optimization=MagicMock(), n_forward_steps=n_steps + stepper = config.get_stepper( + (5, 5), gridded_operations, vertical_coordinate, TIMESTEP ) - # output of run_on_batch does not include the initial condition - assert stepped.gen_data["a"].shape == (5, n_steps, 5, 5) - data = {k: data_with_ic[k][:, 1:] for k in data_with_ic} + stepped = stepper.train_on_batch(data=data_with_ic, optimization=MagicMock()) + # output of train_on_batch does not include the initial condition + assert stepped.gen_data["a"].shape == (5, n_steps + 1, 5, 5) for i in range(n_steps - 1): assert torch.allclose( - 
stepped.gen_data_norm["a"][:, i] + 1, stepped.gen_data_norm["a"][:, i + 1] + stepped.normalize(stepped.gen_data)["a"][:, i] + 1, + stepped.normalize(stepped.gen_data)["a"][:, i + 1], ) assert torch.allclose( - stepped.gen_data_norm["b"][:, i] + 1, stepped.gen_data_norm["b"][:, i + 1] + stepped.normalize(stepped.gen_data)["b"][:, i] + 1, + stepped.normalize(stepped.gen_data)["b"][:, i + 1], ) assert torch.allclose( stepped.gen_data["a"][:, i] + 1, stepped.gen_data["a"][:, i + 1] @@ -177,11 +203,17 @@ def forward(self, x): assert torch.allclose( stepped.gen_data["b"][:, i] + 1, stepped.gen_data["b"][:, i + 1] ) - assert torch.allclose(stepped.target_data_norm["a"], data["a"]) - assert torch.allclose(stepped.target_data_norm["b"], data["b"]) + assert torch.allclose( + stepped.normalize(stepped.target_data)["a"], + data_with_ic.data["a"], + ) + assert torch.allclose( + stepped.normalize(stepped.target_data)["b"], + data_with_ic.data["b"], + ) -def test_run_on_batch_with_prescribed_ocean(): +def test_train_on_batch_with_prescribed_ocean(): torch.manual_seed(0) class AddOne(torch.nn.Module): @@ -189,45 +221,48 @@ def forward(self, x): return x + 1 n_steps = 3 - data = get_data(["a", "b", "mask"], n_samples=5, n_time=n_steps + 1).data - data["mask"] = torch.zeros_like(data["mask"], dtype=torch.int) - data["mask"][:, :, :, 0] = 1 + data: BatchData = get_data(["a", "b", "mask"], n_samples=5, n_time=n_steps + 1).data + data.data["mask"][:] = 0 + data.data["mask"][:, :, :, 0] = 1 stds = { "a": np.array([2.0], dtype=np.float32), "b": np.array([3.0], dtype=np.float32), - "mask": np.array([1.0], dtype=np.float32), } - area = torch.ones((5, 5), device=fme.get_device()) - sigma_coordinates = SigmaCoordinates(ak=torch.arange(7), bk=torch.arange(7)) + area = torch.ones((5, 5), device=DEVICE) + gridded_operations = LatLonOperations(area) + vertical_coordinate = HybridSigmaPressureCoordinate( + ak=torch.arange(7), bk=torch.arange(7) + ) config = SingleModuleStepperConfig( builder=ModuleSelector(type="prebuilt", config={"module": AddOne()}), in_names=["a", "b"], out_names=["a", "b"], normalization=NormalizationConfig( - means=get_scalar_data(["a", "b", "mask"], 0.0), + means=get_scalar_data(["a", "b"], 0.0), stds=stds, ), ocean=OceanConfig("b", "mask"), ) - stepper = config.get_stepper(area.shape, area, sigma_coordinates, TIMESTEP) - stepped = stepper.run_on_batch( - data, optimization=MagicMock(), n_forward_steps=n_steps + stepper = config.get_stepper( + area.shape, gridded_operations, vertical_coordinate, TIMESTEP ) + stepped = stepper.train_on_batch(data, optimization=MagicMock()) for i in range(n_steps - 1): # "a" should be increasing by 1 according to AddOne torch.testing.assert_close( - stepped.gen_data_norm["a"][:, i] + 1, stepped.gen_data_norm["a"][:, i + 1] + stepped.normalize(stepped.gen_data)["a"][:, i] + 1, + stepped.normalize(stepped.gen_data)["a"][:, i + 1], ) # "b" should be increasing by 1 where the mask says don't prescribe # note the 1: selection for the last dimension in following two assertions torch.testing.assert_close( - stepped.gen_data_norm["b"][:, i, :, 1:] + 1, - stepped.gen_data_norm["b"][:, i + 1, :, 1:], + stepped.normalize(stepped.gen_data)["b"][:, i, :, 1:] + 1, + stepped.normalize(stepped.gen_data)["b"][:, i + 1, :, 1:], ) # now check that the 0th index in last dimension has been overwritten torch.testing.assert_close( - stepped.gen_data_norm["b"][:, i, :, 0], - stepped.target_data_norm["b"][:, i, :, 0], + stepped.normalize(stepped.gen_data)["b"][:, i, :, 0], + 
stepped.normalize({"b": stepped.target_data["b"]})["b"][:, i, :, 0], ) @@ -245,47 +280,33 @@ def test_reloaded_stepper_gives_same_prediction(): ), ) shapes = { - "a": (1, 1, 5, 5), - "b": (1, 1, 5, 5), + "a": (1, 2, 5, 5), + "b": (1, 2, 5, 5), } - area = torch.ones((5, 5), device=fme.get_device()) - sigma_coordinates = SigmaCoordinates(ak=torch.arange(7), bk=torch.arange(7)) + area = torch.ones((5, 5), device=DEVICE) + vertical_coordinate = HybridSigmaPressureCoordinate( + ak=torch.arange(7), bk=torch.arange(7) + ) stepper = config.get_stepper( img_shape=shapes["a"][-2:], - area=area, - sigma_coordinates=sigma_coordinates, + gridded_operations=LatLonOperations(area), + vertical_coordinate=vertical_coordinate, timestep=TIMESTEP, ) - area = torch.ones((5, 5), device=fme.get_device()) - new_stepper = SingleModuleStepper.from_state( - stepper.get_state(), area=area, sigma_coordinates=sigma_coordinates - ) + area = torch.ones((5, 5), device=DEVICE) + new_stepper = SingleModuleStepper.from_state(stepper.get_state()) data = get_data(["a", "b"], n_samples=5, n_time=2).data - first_result = stepper.run_on_batch( + first_result = stepper.train_on_batch( data=data, optimization=NullOptimization(), - n_forward_steps=1, ) - second_result = new_stepper.run_on_batch( + second_result = new_stepper.train_on_batch( data=data, optimization=NullOptimization(), - n_forward_steps=1, ) assert torch.allclose(first_result.metrics["loss"], second_result.metrics["loss"]) assert torch.allclose(first_result.gen_data["a"], second_result.gen_data["a"]) assert torch.allclose(first_result.gen_data["b"], second_result.gen_data["b"]) - assert torch.allclose( - first_result.gen_data_norm["a"], second_result.gen_data_norm["a"] - ) - assert torch.allclose( - first_result.gen_data_norm["b"], second_result.gen_data_norm["b"] - ) - assert torch.allclose( - first_result.target_data_norm["a"], second_result.target_data_norm["a"] - ) - assert torch.allclose( - first_result.target_data_norm["b"], second_result.target_data_norm["b"] - ) assert torch.allclose(first_result.target_data["a"], second_result.target_data["a"]) assert torch.allclose(first_result.target_data["b"], second_result.target_data["b"]) @@ -314,24 +335,25 @@ def forward(self, x): return zero + self._param -def _setup_and_run_on_batch( +def _setup_and_train_on_batch( data: TensorDict, in_names, out_names, ocean_config: Optional[OceanConfig], - n_forward_steps, optimization_config: Optional[OptimizationConfig], ): - """Sets up the requisite classes to run run_on_batch.""" + """Sets up the requisite classes to run train_on_batch.""" module = ReturnZerosModule(len(in_names), len(out_names)) if optimization_config is None: optimization: Union[NullOptimization, Optimization] = NullOptimization() else: - optimization = optimization_config.build(module.parameters(), 2) + optimization = optimization_config.build(modules=[module], max_epochs=2) - area = torch.ones((5, 5), device=fme.get_device()) - sigma_coordinates = SigmaCoordinates(ak=torch.arange(7), bk=torch.arange(7)) + area = torch.ones((5, 5), device=DEVICE) + vertical_coordinate = HybridSigmaPressureCoordinate( + ak=torch.arange(7), bk=torch.arange(7) + ) config = SingleModuleStepperConfig( builder=ModuleSelector(type="prebuilt", config={"module": module}), in_names=in_names, @@ -342,10 +364,10 @@ def _setup_and_run_on_batch( ), ocean=ocean_config, ) - stepper = config.get_stepper(area.shape, area, sigma_coordinates, TIMESTEP) - return stepper.run_on_batch( - data, optimization=optimization, 
n_forward_steps=n_forward_steps + stepper = config.get_stepper( + area.shape, LatLonOperations(area), vertical_coordinate, TIMESTEP ) + return stepper.train_on_batch(data, optimization=optimization) @pytest.mark.parametrize( @@ -358,7 +380,7 @@ def _setup_and_run_on_batch( ) @pytest.mark.parametrize("n_forward_steps", [1, 2, 3], ids=lambda p: f"k={p}") @pytest.mark.parametrize("is_train", [True, False], ids=["is_train", ""]) -def test_run_on_batch(n_forward_steps, is_input, is_output, is_train, is_prescribed): +def test_train_on_batch(n_forward_steps, is_input, is_output, is_train, is_prescribed): in_names, out_names = ["a"], ["a"] if is_input: in_names.append("b") @@ -374,15 +396,59 @@ def test_run_on_batch(n_forward_steps, is_input, is_output, is_train, is_prescri else: ocean_config = None - data, area_weights, sigma_coords = get_data(all_names, 3, n_forward_steps + 1) + data, _, _ = get_data(all_names, 3, n_forward_steps + 1) if is_train: optimization = OptimizationConfig() else: optimization = None - _setup_and_run_on_batch( - data, in_names, out_names, ocean_config, n_forward_steps, optimization + _setup_and_train_on_batch(data, in_names, out_names, ocean_config, optimization) + + +@pytest.mark.parametrize("n_forward_steps", [1, 2, 3]) +def test_train_on_batch_one_step_aggregator(n_forward_steps): + in_names, out_names, all_names = ["a"], ["a"], ["a"] + data, _, _ = get_data(all_names, 3, n_forward_steps + 1) + stepper = _get_stepper(in_names, out_names, ocean_config=None, module_name="AddOne") + + aggregator = OneStepAggregator( + gridded_operations=LatLonOperations(torch.ones((5, 5))), + ) + + stepped = stepper.train_on_batch(data, optimization=NullOptimization()) + assert stepped.gen_data["a"].shape[1] == n_forward_steps + 1 + + aggregator.record_batch(stepped) + logs = aggregator.get_logs("one_step") + + gen = data.data["a"].select(dim=1, index=0) + 1 + tar = data.data["a"].select(dim=1, index=1) + + bias = torch.mean(gen - tar) + assert np.isclose(bias.item(), logs["one_step/mean/weighted_bias/a"]) + + residual_gen = torch.ones((5, 5)) + residual_tar = tar[0] - data.data["a"].select(dim=1, index=0)[0] + residual_imgs = [[residual_gen.cpu().numpy()], [residual_tar.cpu().numpy()]] + residual_plot = plot_paneled_data(residual_imgs, diverging=True) + assert np.allclose( + residual_plot.to_data_array(), + logs["one_step/snapshot/image-residual/a"].to_data_array(), + ) + + full_field_gen = gen.mean(dim=0) + full_field_tar = tar.mean(dim=0) + full_field_plot = plot_paneled_data( + [ + [full_field_gen.cpu().numpy()], + [full_field_tar.cpu().numpy()], + ], + diverging=False, + ) + assert np.allclose( + full_field_plot.to_data_array(), + logs["one_step/mean_map/image-full-field/a"].to_data_array(), ) @@ -406,24 +472,28 @@ def forward(self, x): (False, "advection_and_precipitation", True), ], ) -def test_stepper_corrector(global_only: bool, terms_to_modify, force_positive: bool): +@pytest.mark.parametrize("compute_derived_in_train_on_batch", [False, True]) +def test_stepper_corrector( + global_only: bool, + terms_to_modify, + force_positive: bool, + compute_derived_in_train_on_batch: bool, +): torch.random.manual_seed(0) n_forward_steps = 5 device = get_device() data = { - "PRESsfc": 10.0 + torch.rand(size=(3, n_forward_steps + 1, 5, 5)).to(device), + "PRESsfc": 10.0 + torch.rand(size=(3, n_forward_steps + 1, 5, 5)), "specific_total_water_0": -0.2 - + torch.rand(size=(3, n_forward_steps + 1, 5, 5)).to(device), - "specific_total_water_1": torch.rand(size=(3, n_forward_steps + 1, 5, 5)).to( - 
device - ), - "PRATEsfc": torch.rand(size=(3, n_forward_steps + 1, 5, 5)).to(device), - "LHTFLsfc": torch.rand(size=(3, n_forward_steps + 1, 5, 5)).to(device), + + torch.rand(size=(3, n_forward_steps + 1, 5, 5)), + "specific_total_water_1": torch.rand(size=(3, n_forward_steps + 1, 5, 5)), + "PRATEsfc": torch.rand(size=(3, n_forward_steps + 1, 5, 5)), + "LHTFLsfc": torch.rand(size=(3, n_forward_steps + 1, 5, 5)), "tendency_of_total_water_path_due_to_advection": torch.rand( size=(3, n_forward_steps + 1, 5, 5) - ).to(device), + ), } - sigma_coordinates = SigmaCoordinates( + vertical_coordinate = HybridSigmaPressureCoordinate( ak=torch.asarray([3.0, 1.0, 0.0]), bk=torch.asarray([0.0, 0.6, 1.0]) ).to(device) area_weights = 1.0 + torch.rand(size=(5, 5)).to(device) @@ -441,14 +511,12 @@ def test_stepper_corrector(global_only: bool, terms_to_modify, force_positive: b ) mean_advection = metrics.weighted_mean( - data["tendency_of_total_water_path_due_to_advection"], + data["tendency_of_total_water_path_due_to_advection"].to(device), weights=area_weights, dim=[-2, -1], ) assert (mean_advection.abs() > 0.0).all() - # use a randomly initialized Linear layer for the module - # using PrebuiltBuilder stepper_config = SingleModuleStepperConfig( builder=ModuleSelector( type="prebuilt", @@ -466,22 +534,35 @@ def test_stepper_corrector(global_only: bool, terms_to_modify, force_positive: b ) stepper = stepper_config.get_stepper( img_shape=data["PRESsfc"].shape[2:], - area=area_weights, - sigma_coordinates=sigma_coordinates, + gridded_operations=LatLonOperations(area_weights), + vertical_coordinate=vertical_coordinate, timestep=TIMESTEP, ) + time = xr.DataArray( + [ + [ + cftime.DatetimeProlepticGregorian( + 2000, 1, int(i * 6 // 24) + 1, i * 6 % 24 + ) + for i in range(n_forward_steps + 1) + ] + for _ in range(3) + ], + dims=["sample", "time"], + ) + batch_data = BatchData.new_on_cpu( + data=data, + time=time, + ).to_device() # run the stepper on the data with torch.no_grad(): - stepped = stepper.run_on_batch( - data=data, + stepped = stepper.train_on_batch( + data=batch_data, optimization=NullOptimization(), - n_forward_steps=n_forward_steps, + compute_derived_variables=compute_derived_in_train_on_batch, ) - - stepped = compute_stepped_derived_quantities( - stepped, sigma_coordinates=sigma_coordinates, timestep=TIMESTEP - ) - + if not compute_derived_in_train_on_batch: + stepped = stepped.compute_derived_variables() # check that the budget residual is zero budget_residual = stepped.gen_data["total_water_path_budget_residual"] if global_only: @@ -516,7 +597,7 @@ def test_stepper_corrector(global_only: bool, terms_to_modify, force_positive: b dry_air = ( metrics.weighted_mean( ClimateData(stepped.gen_data).surface_pressure_due_to_dry_air( - sigma_coordinates + vertical_coordinate ), weights=area_weights, dim=[-2, -1], @@ -530,7 +611,7 @@ def test_stepper_corrector(global_only: bool, terms_to_modify, force_positive: b # check that positive forcing is enforced if force_positive: for name in force_positive_names: - assert stepped.gen_data[name].min() >= 0.0 + assert stepped.gen_data[name][:, 1:].min() >= 0.0 def _get_stepper( @@ -569,7 +650,9 @@ def forward(self, x): all_names = list(set(in_names + out_names)) area = torch.ones((5, 5)) - sigma_coordinates = SigmaCoordinates(ak=torch.arange(7), bk=torch.arange(7)) + vertical_coordinate = HybridSigmaPressureCoordinate( + ak=torch.arange(7), bk=torch.arange(7) + ) config = SingleModuleStepperConfig( builder=ModuleSelector(type="prebuilt", config=module_config), 
in_names=in_names, @@ -581,12 +664,14 @@ def forward(self, x): ocean=ocean_config, **kwargs, ) - return config.get_stepper((5, 5), area, sigma_coordinates, TIMESTEP) + return config.get_stepper( + (5, 5), LatLonOperations(area), vertical_coordinate, TIMESTEP + ) def test_step(): stepper = _get_stepper(["a", "b"], ["a", "b"]) - input_data = {x: torch.rand(3, 5, 5) for x in ["a", "b"]} + input_data = {x: torch.rand(3, 5, 5).to(DEVICE) for x in ["a", "b"]} output = stepper.step(input_data, {}) @@ -596,7 +681,7 @@ def test_step(): def test_step_with_diagnostic(): stepper = _get_stepper(["a"], ["a", "c"], module_name="RepeatChannel") - input_data = {"a": torch.rand(3, 5, 5)} + input_data = {"a": torch.rand(3, 5, 5).to(DEVICE)} output = stepper.step(input_data, {}) torch.testing.assert_close(output["a"], input_data["a"]) torch.testing.assert_close(output["c"], input_data["a"]) @@ -604,7 +689,7 @@ def test_step_with_diagnostic(): def test_step_with_forcing_and_diagnostic(): stepper = _get_stepper(["a", "b"], ["a", "c"]) - input_data = {x: torch.rand(3, 5, 5) for x in ["a", "b"]} + input_data = {x: torch.rand(3, 5, 5).to(DEVICE) for x in ["a", "b"]} output = stepper.step(input_data, {}) torch.testing.assert_close(output["a"], input_data["a"] + 1) assert "b" not in output @@ -615,8 +700,8 @@ def test_step_with_prescribed_ocean(): stepper = _get_stepper( ["a", "b"], ["a", "b"], ocean_config=OceanConfig("a", "mask") ) - input_data = {x: torch.rand(3, 5, 5) for x in ["a", "b", "mask"]} - ocean_data = {x: torch.rand(3, 5, 5) for x in ["a", "mask"]} + input_data = {x: torch.rand(3, 5, 5).to(DEVICE) for x in ["a", "b"]} + ocean_data = {x: torch.rand(3, 5, 5).to(DEVICE) for x in ["a", "mask"]} output = stepper.step(input_data, ocean_data) expected_a_output = torch.where( torch.round(ocean_data["mask"]).to(int) == 1, @@ -628,51 +713,105 @@ def test_step_with_prescribed_ocean(): assert set(output) == {"a", "b"} +def get_data_for_predict( + n_steps, forcing_names: List[str] +) -> Tuple[PrognosticState, BatchData]: + n_samples = 3 + input_data = BatchData.new_on_device( + data={"a": torch.rand(n_samples, 1, 5, 5).to(DEVICE)}, + time=xr.DataArray( + np.zeros((n_samples, 1)), + dims=["sample", "time"], + ), + ).get_start( + prognostic_names=["a"], + n_ic_timesteps=1, + ) + forcing_data = BatchData.new_on_device( + data={ + name: torch.rand(3, n_steps + 1, 5, 5).to(DEVICE) for name in forcing_names + }, + time=xr.DataArray( + np.zeros((n_samples, n_steps + 1)), + dims=["sample", "time"], + ), + ) + return input_data, forcing_data + + def test_predict(): - stepper = _get_stepper(["a", "b"], ["a", "b"]) + stepper = _get_stepper(["a"], ["a"]) n_steps = 3 - input_data = {x: torch.rand(3, 5, 5) for x in ["a", "b"]} - forcing_data = {} - output = stepper.predict(input_data, forcing_data, n_steps) - for variable in ["a", "b"]: - assert output[variable].size(dim=1) == n_steps - torch.testing.assert_close( - output[variable][:, -1], input_data[variable] + n_steps - ) + input_data, forcing_data = get_data_for_predict(n_steps, forcing_names=[]) + forcing_data.data = {} + output, new_input_data = stepper.predict(input_data, forcing_data) + xr.testing.assert_allclose(forcing_data.time[:, 1:], output.time) + variable = "a" + assert output.data[variable].size(dim=1) == n_steps + torch.testing.assert_close( + output.data[variable][:, -1], + input_data.as_batch_data().data[variable][:, 0] + n_steps, + ) + assert isinstance(new_input_data, PrognosticState) + new_input_state = new_input_data.as_batch_data() + assert 
isinstance(new_input_state, BatchData) + torch.testing.assert_close( + new_input_state.data[variable][:, 0], output.data[variable][:, -1] + ) + assert new_input_state.time.equals(output.time[:, -1:]) def test_predict_with_forcing(): stepper = _get_stepper(["a", "b"], ["a"], module_name="ChannelSum") n_steps = 3 - input_data = {"a": torch.rand(3, 5, 5)} - forcing_data = {"b": torch.rand(3, n_steps + 1, 5, 5)} - output = stepper.predict(input_data, forcing_data, n_steps) - assert "b" not in output - assert output["a"].size(dim=1) == n_steps + input_data, forcing_data = get_data_for_predict(n_steps, forcing_names=["b"]) + output, new_input_data = stepper.predict(input_data, forcing_data) + assert "b" not in output.data + assert output.data["a"].size(dim=1) == n_steps + xr.testing.assert_allclose(forcing_data.time[:, 1:], output.time) torch.testing.assert_close( - output["a"][:, 0], input_data["a"] + forcing_data["b"][:, 0] - ) + output.data["a"][:, 0], + input_data.as_batch_data().data["a"][:, 0] + forcing_data.data["b"][:, 0], + ) + assert isinstance(new_input_data, PrognosticState) + new_input_state = new_input_data.as_batch_data() + assert isinstance(new_input_state, BatchData) + torch.testing.assert_close(new_input_state.data["a"][:, 0], output.data["a"][:, -1]) + assert "b" not in new_input_state.data for n in range(1, n_steps): - expected_a_output = output["a"][:, n - 1] + forcing_data["b"][:, n] - torch.testing.assert_close(output["a"][:, n], expected_a_output) + expected_a_output = output.data["a"][:, n - 1] + forcing_data.data["b"][:, n] + torch.testing.assert_close(output.data["a"][:, n], expected_a_output) + xr.testing.assert_equal(output.time, forcing_data.time[:, 1:]) + assert new_input_state.time.equals(output.time[:, -1:]) def test_predict_with_ocean(): stepper = _get_stepper(["a"], ["a"], ocean_config=OceanConfig("a", "mask")) n_steps = 3 - input_data = {"a": torch.rand(3, 5, 5)} - forcing_data = {x: torch.rand(3, n_steps + 1, 5, 5) for x in ["a", "mask"]} - output = stepper.predict(input_data, forcing_data, n_steps) - assert "mask" not in output - assert output["a"].size(dim=1) == n_steps + input_data, forcing_data = get_data_for_predict( + n_steps, forcing_names=["a", "mask"] + ) + output, new_input_data = stepper.predict(input_data, forcing_data) + xr.testing.assert_allclose(forcing_data.time[:, 1:], output.time) + assert "mask" not in output.data + assert output.data["a"].size(dim=1) == n_steps for n in range(n_steps): - previous_a = input_data["a"] if n == 0 else output["a"][:, n - 1] + previous_a = ( + input_data.as_batch_data().data["a"][:, 0] + if n == 0 + else output.data["a"][:, n - 1] + ) expected_a_output = torch.where( - torch.round(forcing_data["mask"][:, n + 1]).to(int) == 1, - forcing_data["a"][:, n + 1], + torch.round(forcing_data.data["mask"][:, n + 1]).to(int) == 1, + forcing_data.data["a"][:, n + 1], previous_a + 1, ) - torch.testing.assert_close(output["a"][:, n], expected_a_output) + torch.testing.assert_close(output.data["a"][:, n], expected_a_output) + assert isinstance(new_input_data, PrognosticState) + new_input_state = new_input_data.as_batch_data() + assert isinstance(new_input_state, BatchData) + torch.testing.assert_close(new_input_state.data["a"][:, 0], output.data["a"][:, -1]) + assert new_input_state.time.equals(output.time[:, -1:]) def test_next_step_forcing_names(): @@ -682,39 +821,46 @@ def test_next_step_forcing_names(): module_name="ChannelSum", next_step_forcing_names=["c"], ) - input_data = {x: torch.rand(1, 5, 5) for x in ["a"]} - 
forcing_data = {x: torch.rand(1, 2, 5, 5) for x in ["b", "c"]} - stepper.predict(input_data, forcing_data, 1) + input_data, forcing_data = get_data_for_predict(n_steps=1, forcing_names=["b", "c"]) + stepper.predict(input_data, forcing_data) torch.testing.assert_close( - stepper.module.module.last_input[:, 1, :], forcing_data["b"][:, 0] + stepper.module.module.last_input[:, 1, :], forcing_data.data["b"][:, 0] ) torch.testing.assert_close( - stepper.module.module.last_input[:, 2, :], forcing_data["c"][:, 1] + stepper.module.module.last_input[:, 2, :], forcing_data.data["c"][:, 1] ) def test_prepend_initial_condition(): nt = 3 - x = torch.rand(3, nt, 5).to(fme.get_device()) - x_normed = (x - x.mean()) / x.std() - stepped = SteppedData( + x = torch.rand(3, nt, 5).to(DEVICE) + + def normalize(x): + result = {k: (v - 1) / 2 for k, v in x.items()} + return result + + stepped = TrainOutput( gen_data={"a": x, "b": x + 1}, - gen_data_norm={"a": x_normed, "b": x_normed + 1}, - target_data={"a": x, "b": x + 1}, - target_data_norm={"a": x_normed, "b": x_normed + 1}, + target_data={"a": x + 2, "b": x + 3}, + time=xr.DataArray(np.zeros((3, nt)), dims=["sample", "time"]), metrics={"loss": torch.tensor(0.0)}, + normalize=normalize, ) - ic = { - "a": torch.rand(3, 5).to(fme.get_device()), - "b": torch.rand(3, 5).to(fme.get_device()), + ic_data = { + "a": torch.rand(3, 1, 5).to(DEVICE), + "b": torch.rand(3, 1, 5).to(DEVICE), } - ic_normed = {k: (v - v.mean()) / v.std() for k, v in ic.items()} - prepended = stepped.prepend_initial_condition(ic, ic_normed) + ic = BatchData.new_on_device( + data=ic_data, + time=xr.DataArray(np.zeros((3, 1)), dims=["sample", "time"]), + ).get_start( + prognostic_names=["a", "b"], + n_ic_timesteps=1, + ) + prepended = stepped.prepend_initial_condition(ic) for v in ["a", "b"]: - assert torch.allclose(prepended.gen_data[v][:, 0], ic[v]) - assert torch.allclose(prepended.gen_data_norm[v][:, 0], ic_normed[v]) - assert torch.allclose(prepended.target_data[v][:, 0], ic[v]) - assert torch.allclose(prepended.target_data_norm[v][:, 0], ic_normed[v]) + assert torch.allclose(prepended.gen_data[v][:, :1], ic_data[v]) + assert torch.allclose(prepended.target_data[v][:, :1], ic_data[v]) def test__combine_normalizers(): @@ -752,62 +898,48 @@ def test_stepper_from_state_using_resnorm_has_correct_normalizer(): # stepper loaded from state should have the appropriately combined # full field and residual values in its loss_normalizer torch.manual_seed(0) - full_field_normalization = { - "means": {"a": 0.0, "b": 0.0, "diagnostic": 0.0}, - "stds": {"a": 1.0, "b": 1.0, "diagnostic": 1.0}, - } + full_field_means = {"a": 0.0, "b": 0.0, "diagnostic": 0.0} + full_field_stds = {"a": 1.0, "b": 1.0, "diagnostic": 1.0} # residual scalings might have diagnostic variables but the stepper # should detect which prognostic variables to use from the set - residual_normalization = { - "means": {"a": 1.0, "b": 1.0, "diagnostic": 1.0}, - "stds": {"a": 2.0, "b": 2.0, "diagnostic": 2.0}, - } + residual_means = {"a": 1.0, "b": 1.0, "diagnostic": 1.0} + residual_stds = {"a": 2.0, "b": 2.0, "diagnostic": 2.0} config = SingleModuleStepperConfig( builder=ModuleSelector( type="SphericalFourierNeuralOperatorNet", config={"scale_factor": 1} ), in_names=["a", "b"], out_names=["a", "b", "diagnostic"], - normalization=NormalizationConfig(**full_field_normalization), - residual_normalization=NormalizationConfig(**residual_normalization), + normalization=NormalizationConfig(means=full_field_means, stds=full_field_stds), + 
residual_normalization=NormalizationConfig( + means=residual_means, stds=residual_stds + ), ) shapes = { "a": (1, 1, 5, 5), "b": (1, 1, 5, 5), "diagnostic": (1, 1, 5, 5), } - area = torch.ones((5, 5), device=fme.get_device()) - sigma_coordinates = SigmaCoordinates(ak=torch.arange(7), bk=torch.arange(7)) + area = torch.ones((5, 5), device=DEVICE) + vertical_coordinate = HybridSigmaPressureCoordinate( + ak=torch.arange(7), bk=torch.arange(7) + ) orig_stepper = config.get_stepper( img_shape=shapes["a"][-2:], - area=area, - sigma_coordinates=sigma_coordinates, + gridded_operations=LatLonOperations(area), + vertical_coordinate=vertical_coordinate, timestep=TIMESTEP, ) - stepper_from_state = SingleModuleStepper.from_state( - orig_stepper.get_state(), area=area, sigma_coordinates=sigma_coordinates - ) + stepper_from_state = SingleModuleStepper.from_state(orig_stepper.get_state()) for stepper in [orig_stepper, stepper_from_state]: - assert stepper.loss_normalizer.means == {"a": 1.0, "b": 1.0, "diagnostic": 0.0} - assert stepper.loss_normalizer.stds == {"a": 2.0, "b": 2.0, "diagnostic": 1.0} - assert stepper.normalizer.means == full_field_normalization["means"] - assert stepper.normalizer.stds == full_field_normalization["stds"] - - -def test_stepper_effective_loss_scaling(): - custom_loss_weights = {"b": 2.0} - loss_norm_means = {"a": 0.0, "b": 0.0} - loss_norm_stds = {"a": 4.0, "b": 0.5} - stepper = _get_stepper( - in_names=["a", "b"], - out_names=["a", "b"], - loss=WeightedMappingLossConfig(weights=custom_loss_weights), - loss_normalization=NormalizationConfig( - means=loss_norm_means, stds=loss_norm_stds - ), - ) - assert stepper.effective_loss_scaling == { - "a": torch.tensor(4.0), - "b": torch.tensor(0.25), - } + assert stepper.loss_normalizer.means == { + **residual_means, + "diagnostic": full_field_means["diagnostic"], + } + assert stepper.loss_normalizer.stds == { + **residual_stds, + "diagnostic": full_field_stds["diagnostic"], + } + assert stepper.normalizer.means == full_field_means + assert stepper.normalizer.stds == full_field_stds diff --git a/fme/fme/ace/test_train.py b/fme/fme/ace/test_train.py index 6b2327d..e62c0ee 100755 --- a/fme/fme/ace/test_train.py +++ b/fme/fme/ace/test_train.py @@ -1,6 +1,9 @@ +import copy +import dataclasses import pathlib import subprocess import tempfile +import textwrap import unittest.mock import numpy as np @@ -10,17 +13,34 @@ import yaml from fme.ace.inference.evaluator import main as inference_evaluator_main -from fme.ace.train.train import _restore_checkpoint, count_parameters -from fme.ace.train.train import main as train_main -from fme.ace.train.train_config import epoch_checkpoint_enabled -from fme.core.data_loading.config import Slice -from fme.core.testing import ( +from fme.ace.registry.test_hpx import ( + conv_next_block_config, + decoder_config, + down_sampling_block_config, + encoder_config, + output_layer_config, + recurrent_block_config, + up_sampling_block_config, +) +from fme.ace.testing import ( DimSizes, MonthlyReferenceData, - save_2d_netcdf, + save_nd_netcdf, save_scalar_netcdf, ) +from fme.ace.train.train import main as train_main +from fme.core.coordinates import ( + HEALPixCoordinates, + HorizontalCoordinates, + LatLonCoordinates, +) +from fme.core.generics.trainer import ( + _restore_checkpoint, + count_parameters, + epoch_checkpoint_enabled, +) from fme.core.testing.wandb import mock_wandb +from fme.core.typing_ import Slice REPOSITORY_PATH = pathlib.PurePath(__file__).parent.parent.parent.parent JOB_SUBMISSION_SCRIPT_PATH = 
( @@ -46,7 +66,67 @@ def _get_test_yaml_files( max_epochs=1, segment_epochs=1, inference_forward_steps=2, + use_healpix=False, ): + input_time_size = 1 + output_time_size = 1 + if use_healpix: + in_channels = len(in_variable_names) + out_channels = len(out_variable_names) + prognostic_variables = min( + out_channels, in_channels + ) # how many variables in/out share. + # in practice, we will need to compare variable names, since there + # are some input-only and some output-only channels. + # TODO: https://github.com/ai2cm/full-model/issues/1046 + n_constants = 0 + decoder_input_channels = 0 # was 1, to indicate insolation - now 0 + input_time_size = 1 # TODO: change to 2 (issue #1177) + output_time_size = 1 # TODO: change to 4 (issue #1177) + spatial_dimensions_str = "healpix" + + conv_next_block = conv_next_block_config(in_channels=in_channels) + down_sampling_block = down_sampling_block_config() + recurrent_block = recurrent_block_config() + encoder = encoder_config( + conv_next_block, down_sampling_block, n_channels=[16, 8, 4] + ) + up_sampling_block = up_sampling_block_config() + output_layer = output_layer_config() + decoder = decoder_config( + conv_next_block, + up_sampling_block, + output_layer, + recurrent_block, + n_channels=[4, 8, 16], + ) + + # Need to manually indent these YAML strings + encoder_yaml = yaml.dump( + dataclasses.asdict(encoder), default_flow_style=False, indent=2 + ) + decoder_yaml = yaml.dump( + dataclasses.asdict(decoder), default_flow_style=False, indent=2 + ) + encoder_yaml_indented = textwrap.indent(encoder_yaml, " ") + decoder_yaml_indented = textwrap.indent(decoder_yaml, " ") + config_str = f""" + encoder: +{encoder_yaml_indented} + decoder: +{decoder_yaml_indented} + prognostic_variables: {prognostic_variables} + n_constants: {n_constants} + decoder_input_channels: {decoder_input_channels} + input_time_size: {input_time_size} + output_time_size: {output_time_size} + """ + else: + config_str = """ + num_layers: 2 + embed_dim: 12""" + spatial_dimensions_str = "latlon" + new_stepper_config = f""" in_names: {in_variable_names} out_names: {out_variable_names} @@ -57,12 +137,10 @@ def _get_test_yaml_files( global_means_path: '{global_means_path}' global_stds_path: '{global_stds_path}' loss: - global_mean_type: "LpLoss" + type: "MSE" builder: type: {nettype} - config: - num_layers: 2 - embed_dim: 12 + config: {config_str} ocean: surface_temperature_name: {in_variable_names[0]} ocean_fraction_name: {mask_name} @@ -80,17 +158,18 @@ def _get_test_yaml_files( train_loader: dataset: - data_path: '{train_data_path}' + spatial_dimensions: {spatial_dimensions_str} batch_size: 2 - num_data_workers: 1 + num_data_workers: 0 validation_loader: dataset: - data_path: '{valid_data_path}' + spatial_dimensions: {spatial_dimensions_str} batch_size: 2 - num_data_workers: 1 + num_data_workers: 0 optimization: optimizer_type: "Adam" lr: 0.001 - enable_automatic_mixed_precision: true scheduler: type: CosineAnnealingLR kwargs: @@ -102,7 +181,8 @@ def _get_test_yaml_files( monthly_reference_data: {monthly_data_filename} loader: dataset: - data_path: '{valid_data_path}' + data_path: '{valid_data_path}' + spatial_dimensions: {spatial_dimensions_str} start_indices: first: 0 n_initial_conditions: 2 @@ -139,6 +219,7 @@ def _get_test_yaml_files( loader: dataset: data_path: '{valid_data_path}' + spatial_dimensions: {spatial_dimensions_str} start_indices: first: 0 n_initial_conditions: 2 @@ -156,6 +237,23 @@ def _get_test_yaml_files( return f_train.name, f_inference.name +def get_sizes( + 
spatial_dims: HorizontalCoordinates = LatLonCoordinates( + lon=torch.Tensor(np.arange(32)), + lat=torch.Tensor(np.arange(16)), + loaded_lat_name="grid_yt", + loaded_lon_name="grid_xt", + ), + n_time=3, + nz_interface=3, +) -> DimSizes: + return DimSizes( + n_time=n_time, + horizontal=copy.deepcopy(spatial_dims.loaded_sizes), + nz_interface=nz_interface, + ) + + def _setup( path, nettype, @@ -165,22 +263,31 @@ def _setup( n_time=10, timestep_days=5, inference_forward_steps=2, + use_healpix=False, ): if not path.exists(): path.mkdir() seed = 0 np.random.seed(seed) - in_variable_names = ["foo", "bar", "baz"] - out_variable_names = ["foo", "bar"] + in_variable_names = [ + "PRESsfc", + "specific_total_water_0", + "specific_total_water_1", + "baz", + ] + out_variable_names = ["PRESsfc", "specific_total_water_0", "specific_total_water_1"] mask_name = "mask" all_variable_names = list(set(in_variable_names + out_variable_names)) - dim_sizes = DimSizes( - n_time=n_time, - n_lat=16, - n_lon=32, - nz_interface=7, - ) + if use_healpix: + hpx_coords = HEALPixCoordinates( + face=torch.Tensor(np.arange(12)), + width=torch.Tensor(np.arange(16)), + height=torch.Tensor(np.arange(16)), + ) + dim_sizes = get_sizes(spatial_dims=hpx_coords, n_time=n_time) + else: + dim_sizes = get_sizes(n_time=n_time) data_dir = path / "data" stats_dir = path / "stats" @@ -188,7 +295,7 @@ def _setup( data_dir.mkdir() stats_dir.mkdir() results_dir.mkdir() - save_2d_netcdf( + save_nd_netcdf( data_dir / "data.nc", dim_sizes, variable_names=all_variable_names + [mask_name], @@ -203,15 +310,22 @@ def _setup( variable_names=all_variable_names, ) + monthly_dim_sizes: DimSizes + if use_healpix: + hpx_coords = HEALPixCoordinates( + face=torch.Tensor(np.arange(12)), + width=torch.Tensor(np.arange(16)), + height=torch.Tensor(np.arange(16)), + ) + monthly_dim_sizes = get_sizes( + spatial_dims=hpx_coords, n_time=10 * 12, nz_interface=1 + ) + else: + monthly_dim_sizes = get_sizes(n_time=10 * 12, nz_interface=1) monthly_reference_data = MonthlyReferenceData( path=data_dir, names=out_variable_names, - dim_sizes=DimSizes( - n_time=10 * 12, - n_lat=16, - n_lon=32, - nz_interface=1, - ), + dim_sizes=monthly_dim_sizes, n_ensemble=3, ) @@ -230,61 +344,84 @@ def _setup( max_epochs=max_epochs, segment_epochs=segment_epochs, inference_forward_steps=inference_forward_steps, + use_healpix=use_healpix, ) return train_config_filename, inference_config_filename @pytest.mark.parametrize( - "nettype", ["SphericalFourierNeuralOperatorNet", "SFNO-v0.1.0"] + "nettype", ["SphericalFourierNeuralOperatorNet", "HEALPixRecUNet", "SFNO-v0.1.0"] ) -def test_train_and_inference_inline(tmp_path, nettype): +def test_train_and_inference_inline(tmp_path, nettype, very_fast_only: bool): """Make sure that training and inference run without errors Args: tmp_path: pytext fixture for temporary workspace. nettype: parameter indicating model architecture to use. - debug: option for developers to allow use of pdb. + very_fast_only: parameter indicating whether to skip slow tests. 
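The very_fast_only fixture these tests receive (and the --fast / --very-fast flags behind the new make test_fast / test_very_fast targets) comes from a conftest that is not part of this diff. A minimal sketch of how such a fixture could be wired up; everything beyond the flag and fixture names visible here is assumed.

# conftest.py sketch (hypothetical; the real conftest is outside this excerpt)
import pytest


def pytest_addoption(parser):
    parser.addoption(
        "--fast", action="store_true", default=False, help="skip slow tests"
    )
    parser.addoption(
        "--very-fast",
        action="store_true",
        default=False,
        help="run only the quickest subset of tests",
    )


@pytest.fixture
def very_fast_only(request) -> bool:
    # tests receive a plain bool and call pytest.skip themselves
    return request.config.getoption("--very-fast")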
""" + if very_fast_only: + pytest.skip("Skipping non-fast tests") # need multi-year to cover annual aggregator train_config, inference_config = _setup( tmp_path, nettype, + log_to_wandb=True, timestep_days=20, n_time=int(366 * 3 / 20 + 1), inference_forward_steps=int(366 * 3 / 20 / 2 - 1) * 2, # must be even + use_healpix=(nettype == "HEALPixRecUNet"), ) # using pdb requires calling main functions directly - train_main( - yaml_config=train_config, - ) + with mock_wandb() as wandb: + train_main( + yaml_config=train_config, + ) + wandb_logs = wandb.get_logs() + + for log in wandb_logs: + # ensure inference time series is not logged + assert "inference/mean/forecast_step" not in log + # inference should not require stats files (tmp_path / "stats" / "stats-mean.nc").unlink() (tmp_path / "stats" / "stats-stddev.nc").unlink() - inference_logs = inference_evaluator_main(yaml_config=inference_config) - assert len(inference_logs) == 7 # 6 forward steps + 1 initial state + + with mock_wandb() as wandb: + wandb.configure(log_to_wandb=True) + inference_evaluator_main(yaml_config=inference_config) + inference_logs = wandb.get_logs() + prediction_output_path = tmp_path / "output" / "autoregressive_predictions.nc" - assert prediction_output_path.exists() best_checkpoint_path = ( tmp_path / "output" / "training_checkpoints" / "best_ckpt.tar" ) - assert best_checkpoint_path.exists() best_inference_checkpoint_path = ( tmp_path / "output" / "training_checkpoints" / "best_inference_ckpt.tar" ) + assert best_checkpoint_path.exists() assert best_inference_checkpoint_path.exists() + n_ic_timesteps = 1 + n_forward_steps = 6 + n_summary_steps = 1 + assert len(inference_logs) == n_ic_timesteps + n_forward_steps + n_summary_steps + assert prediction_output_path.exists() ds_prediction = xr.open_dataset(prediction_output_path) - assert np.sum(np.isnan(ds_prediction["foo"].values)) == 0 - assert np.sum(np.isnan(ds_prediction["bar"].values)) == 0 + assert np.sum(np.isnan(ds_prediction["PRESsfc"].values)) == 0 + assert np.sum(np.isnan(ds_prediction["specific_total_water_0"].values)) == 0 + assert np.sum(np.isnan(ds_prediction["specific_total_water_1"].values)) == 0 ds_target = xr.open_dataset(tmp_path / "output" / "autoregressive_target.nc") assert np.sum(np.isnan(ds_target["baz"].values)) == 0 @pytest.mark.parametrize("nettype", ["SphericalFourierNeuralOperatorNet"]) -def test_resume(tmp_path, nettype): +def test_resume(tmp_path, nettype, very_fast_only: bool): """Make sure the training is resumed from a checkpoint when restarted.""" + if very_fast_only: + pytest.skip("Skipping non-fast tests") mock = unittest.mock.MagicMock(side_effect=_restore_checkpoint) - with unittest.mock.patch("fme.ace.train.train._restore_checkpoint", new=mock): + with unittest.mock.patch("fme.core.generics.trainer._restore_checkpoint", new=mock): train_config, _ = _setup( tmp_path, nettype, log_to_wandb=True, max_epochs=2, segment_epochs=1 ) @@ -292,28 +429,16 @@ def test_resume(tmp_path, nettype): train_main( yaml_config=train_config, ) - assert ( - min([val["epoch"] for val in wandb.get_logs().values() if "epoch" in val]) - == 0 - ) - assert ( - max([val["epoch"] for val in wandb.get_logs().values() if "epoch" in val]) - == 0 - ) + assert min([val["epoch"] for val in wandb.get_logs() if "epoch" in val]) == 0 + assert max([val["epoch"] for val in wandb.get_logs() if "epoch" in val]) == 0 assert not mock.called with mock_wandb() as wandb: train_main( yaml_config=train_config, ) mock.assert_called() - assert ( - min([val["epoch"] for val in 
wandb.get_logs().values() if "epoch" in val]) - == 1 - ) - assert ( - max([val["epoch"] for val in wandb.get_logs().values() if "epoch" in val]) - == 1 - ) + assert min([val["epoch"] for val in wandb.get_logs() if "epoch" in val]) == 1 + assert max([val["epoch"] for val in wandb.get_logs() if "epoch" in val]) == 1 @pytest.mark.parametrize("nettype", ["SphericalFourierNeuralOperatorNet"]) @@ -349,29 +474,33 @@ def _create_fine_tuning_config(path_to_train_config_yaml: str, path_to_checkpoin with open(path_to_train_config_yaml, "r") as config_file: config_data = yaml.safe_load(config_file) config_data["stepper"] = {"checkpoint_path": path_to_checkpoint} + current_experiment_dir = config_data["experiment_dir"] + new_experiment_dir = pathlib.Path(current_experiment_dir) / "fine_tuning" + config_data["experiment_dir"] = str(new_experiment_dir) with tempfile.NamedTemporaryFile( mode="w", delete=False, suffix=".yaml" ) as new_config_file: new_config_file.write(yaml.dump(config_data)) - return new_config_file.name + return new_config_file.name, new_experiment_dir @pytest.mark.parametrize("nettype", ["SphericalFourierNeuralOperatorNet"]) -def test_fine_tuning(tmp_path, nettype): +def test_fine_tuning(tmp_path, nettype, very_fast_only: bool): """Check that fine tuning config runs without errors.""" + if very_fast_only: + pytest.skip("Skipping non-fast tests") train_config, _ = _setup(tmp_path, nettype) - train_main( - yaml_config=train_config, - ) + train_main(yaml_config=train_config) results_dir = tmp_path / "output" ckpt = f"{results_dir}/training_checkpoints/best_ckpt.tar" - fine_tuning_config = _create_fine_tuning_config(train_config, ckpt) + fine_tuning_config, new_results_dir = _create_fine_tuning_config(train_config, ckpt) train_main(yaml_config=fine_tuning_config) + assert (new_results_dir / "training_checkpoints" / "ckpt.tar").exists() def _create_copy_weights_after_batch_config( diff --git a/fme/fme/ace/testing/__init__.py b/fme/fme/ace/testing/__init__.py new file mode 100644 index 0000000..034a41a --- /dev/null +++ b/fme/fme/ace/testing/__init__.py @@ -0,0 +1,9 @@ +from .fv3gfs_data import ( + DimSize, + DimSizes, + FV3GFSData, + MonthlyReferenceData, + StatsData, + save_nd_netcdf, + save_scalar_netcdf, +) diff --git a/fme/fme/core/testing/fv3gfs_data.py b/fme/fme/ace/testing/fv3gfs_data.py similarity index 71% rename from fme/fme/core/testing/fv3gfs_data.py rename to fme/fme/ace/testing/fv3gfs_data.py index fede51b..3b20d42 100644 --- a/fme/fme/core/testing/fv3gfs_data.py +++ b/fme/fme/ace/testing/fv3gfs_data.py @@ -1,17 +1,18 @@ import dataclasses import datetime import pathlib -from typing import List, Optional, Sequence, Tuple +from typing import Any, List, Optional, Sequence, Tuple import cftime import numpy as np import xarray as xr -from fme.core.data_loading.config import XarrayDataConfig -from fme.core.data_loading.inference import ( +from fme.ace.data_loading.inference import ( InferenceDataLoaderConfig, InferenceInitialConditionIndices, ) +from fme.core.coordinates import DimSize +from fme.core.dataset.config import XarrayDataConfig def _coord_value( @@ -30,51 +31,59 @@ def _coord_value( @dataclasses.dataclass class DimSizes: n_time: int - n_lat: int - n_lon: int + horizontal: List[DimSize] nz_interface: int @property - def shape_2d(self) -> Tuple[int, int, int]: - return (self.n_time, self.n_lat, self.n_lon) + def shape_nd(self) -> List[int]: + return [self.n_time] + [dim.size for dim in self.horizontal] @property - def dims_2d(self) -> Tuple[str, str, str]: - return ("time", 
"grid_yt", "grid_xt") + def dims_nd(self) -> List[str]: + return ["time"] + [dim.name for dim in self.horizontal] @property def shape_vertical_interface(self) -> Tuple[int]: return (self.nz_interface,) + @property def items(self): - return [ - ("time", self.n_time), - ("grid_yt", self.n_lat), - ("grid_xt", self.n_lon), + return [("time", self.n_time)] + [ + (dim.name, dim.size) for dim in self.horizontal ] + def get_size(self, key: str) -> int: + for item in self.horizontal: + if item.name == key: + return item.size + raise KeyError(f"Dimension with name '{key}' not found.") -def save_2d_netcdf( + +def save_nd_netcdf( filename, dim_sizes: DimSizes, variable_names: List[str], timestep_days: float, time_varying_values: Optional[List[float]] = None, + save_vertical_coordinate: bool = True, ): """ - Save a 2D netcdf file with random data for the given variable names and + Save a ND netcdf file with random data for the given variable names and dimensions. Args: filename: The filename to save the netcdf file to. dim_sizes: The dimensions of the data. variable_names: The names of the variables to save. + timestep_days: The number of days between each time step. time_varying_values: If not None, the values to use for each time step. + save_vertical_coordinate: If True, save vertical coordinate variables. """ - ds = get_2d_dataset( + ds = get_nd_dataset( dim_sizes=dim_sizes, variable_names=variable_names, timestep_days=timestep_days, + include_vertical_coordinate=save_vertical_coordinate, ) if time_varying_values is not None: for name in variable_names: @@ -122,6 +131,8 @@ class FV3GFSData: dim_sizes: DimSizes timestep_days: float time_varying_values: Optional[List[float]] = None + num_data_workers: int = 0 + save_vertical_coordinate: bool = True def __post_init__(self): self.data_path.mkdir(parents=True, exist_ok=True) @@ -133,12 +144,13 @@ def __post_init__(self): f"Number of time-varying values ({len(self.time_varying_values)}) " f"must match number of time steps ({self.dim_sizes.n_time})" ) - save_2d_netcdf( - self._data_filename, + save_nd_netcdf( + self.data_filename, dim_sizes=self.dim_sizes, variable_names=self.names, timestep_days=self.timestep_days, time_varying_values=self.time_varying_values, + save_vertical_coordinate=self.save_vertical_coordinate, ) @property @@ -147,7 +159,7 @@ def data_path(self): return self.path / "data" @property - def _data_filename(self): + def data_filename(self): return self.data_path / "data.nc" @property @@ -159,7 +171,7 @@ def inference_data_loader_config(self) -> InferenceDataLoaderConfig: start_indices=InferenceInitialConditionIndices( first=0, n_initial_conditions=1, interval=1 ), - num_data_workers=2, + num_data_workers=self.num_data_workers, ) @@ -172,7 +184,7 @@ class MonthlyReferenceData: def __post_init__(self): self.data_path.mkdir(parents=True, exist_ok=True) - ds = get_2d_dataset( + ds = get_nd_dataset( dim_sizes=self.dim_sizes, variable_names=self.names, ) @@ -207,33 +219,43 @@ def data_filename(self): return self.data_path / "monthly.nc" -def get_2d_dataset( +def get_nd_overrides(dim_sizes: DimSizes, timestep_days: float) -> dict[str, Any]: + coords_override: dict[str, Any] + time = [ + cftime.DatetimeProlepticGregorian(2000, 1, 1) + + i * timestep_days * datetime.timedelta(days=1) + for i in range(dim_sizes.n_time) + ] + coords_override = {"time": time} + if "grid_yt" in dim_sizes.dims_nd: + n_lat = dim_sizes.get_size("grid_yt") # n_lat + n_lon = dim_sizes.get_size("grid_xt") # n_lon + grid_yt = np.linspace(-89.5, 89.5, n_lat) + grid_xt_start 
= 360.0 / n_lon / 2 + grid_xt = np.linspace(grid_xt_start, 360.0 - grid_xt_start, n_lon) + coords_override["grid_yt"] = grid_yt + coords_override["grid_xt"] = grid_xt + return coords_override + + +def get_nd_dataset( dim_sizes: DimSizes, variable_names: Sequence[str], timestep_days: float = 1.0, + include_vertical_coordinate: bool = True, ): """ - Gets a dataset of [time, lat, lon] data. + Gets a dataset of [time, ] data. """ data_vars = {} for name in variable_names: - data = np.random.randn(*dim_sizes.shape_2d).astype(np.float32) + data = np.random.randn(*dim_sizes.shape_nd).astype(np.float32) data_vars[name] = xr.DataArray( data, - dims=["time", "grid_yt", "grid_xt"], + dims=dim_sizes.dims_nd, attrs={"units": "m", "long_name": name}, ) - - grid_yt = np.linspace(-89.5, 89.5, dim_sizes.n_lat) - grid_xt_start = 360.0 / dim_sizes.n_lon / 2 - grid_xt = np.linspace(grid_xt_start, 360.0 - grid_xt_start, dim_sizes.n_lon) - time = [ - cftime.DatetimeProlepticGregorian(2000, 1, 1) - + i * timestep_days * datetime.timedelta(days=1) - for i in range(dim_sizes.n_time) - ] - - coords_override = {"grid_yt": grid_yt, "grid_xt": grid_xt, "time": time} + coords_override = get_nd_overrides(dim_sizes=dim_sizes, timestep_days=timestep_days) coords = { dim_name: ( @@ -244,12 +266,13 @@ def get_2d_dataset( if dim_name not in coords_override else coords_override[dim_name] ) - for dim_name, size in dim_sizes.items() + for dim_name, size in dim_sizes.items } - for i in range(dim_sizes.nz_interface): - data_vars[f"ak_{i}"] = np.float64(i) - data_vars[f"bk_{i}"] = np.float64(i + 1) + if include_vertical_coordinate: + for i in range(dim_sizes.nz_interface): + data_vars[f"ak_{i}"] = np.float64(i) + data_vars[f"bk_{i}"] = np.float64(i + 1) ds = xr.Dataset(data_vars=data_vars, coords=coords) return ds diff --git a/fme/fme/ace/train/__init__.py b/fme/fme/ace/train/__init__.py index 35d2319..d95e567 100644 --- a/fme/fme/ace/train/__init__.py +++ b/fme/fme/ace/train/__init__.py @@ -1 +1,2 @@ -from fme.ace.train.train import Trainer, _restore_checkpoint, count_parameters +from fme.ace.train.train import Trainer +from fme.core.generics.trainer import count_parameters diff --git a/fme/fme/ace/train/train.py b/fme/fme/ace/train/train.py index aecf943..15355ec 100644 --- a/fme/fme/ace/train/train.py +++ b/fme/fme/ace/train/train.py @@ -48,484 +48,171 @@ # Karthik Kashinath - NVIDIA Corporation # Animashree Anandkumar - California Institute of Technology, NVIDIA Corporation -import contextlib import dataclasses -import gc import logging import os -import time -import uuid -from typing import Optional +from datetime import timedelta +from typing import Callable, Dict, Mapping, Optional, Sequence import dacite import dask import torch +import xarray as xr import yaml import fme import fme.core.logging_utils as logging_utils -from fme.ace.inference import run_inference_evaluator -from fme.ace.inference.derived_variables import compute_stepped_derived_quantities -from fme.ace.train.train_config import TrainConfig -from fme.core.aggregator import ( +from fme.ace.aggregator import OneStepAggregator, TrainAggregator +from fme.ace.aggregator.inference.main import ( + InferenceEvaluatorAggregator, InferenceEvaluatorAggregatorConfig, - OneStepAggregator, - TrainAggregator, ) -from fme.core.data_loading.getters import get_data_loader, get_inference_data -from fme.core.data_loading.utils import BatchData +from fme.ace.data_loading.batch_data import PairedData, PrognosticState +from fme.ace.stepper import TrainOutput +from 
fme.ace.train.train_config import TrainBuilders, TrainConfig +from fme.core.coordinates import HorizontalCoordinates, HybridSigmaPressureCoordinate +from fme.core.dataset.data_typing import VariableMetadata +from fme.core.dicts import to_flat_dict from fme.core.distributed import Distributed -from fme.core.ema import EMATracker -from fme.core.optimization import NullOptimization -from fme.core.stepper import SingleModuleStepper -from fme.core.wandb import WandB +from fme.core.generics.trainer import AggregatorBuilderABC, TrainConfigProtocol, Trainer +from fme.core.gridded_ops import GriddedOperations +from fme.core.typing_ import TensorDict, TensorMapping # dask used on individual workers to load batches dask.config.set(scheduler="synchronous") -def count_parameters(modules: torch.nn.ModuleList) -> int: - parameters = 0 - for module in modules: - for parameter in module.parameters(): - if parameter.requires_grad: - parameters += parameter.numel() - return parameters +def build_trainer(builder: TrainBuilders, config: TrainConfig) -> "Trainer": + # note for devs: you don't have to use this function to build a custom + # trainer, you can build it however you like. This is here for convenience. + train_data = builder.get_train_data() + validation_data = builder.get_validation_data() + inference_data = builder.get_evaluation_inference_data() - -class Trainer: - def __init__(self, config: TrainConfig): - logging.info(f"Current device is {fme.get_device()}") - self.dist = Distributed.get_instance() - if self.dist.is_root(): - if not os.path.isdir(config.experiment_dir): - os.makedirs(config.experiment_dir) - if not os.path.isdir(config.checkpoint_dir): - os.makedirs(config.checkpoint_dir) - self.config = config - - data_requirements = config.stepper.get_data_requirements( - n_forward_steps=self.config.n_forward_steps - ) - logging.info("rank %d, begin data loader init" % self.dist.rank) - self.train_data = get_data_loader( - config.train_loader, - requirements=data_requirements, - train=True, - ) - self.valid_data = get_data_loader( - config.validation_loader, - requirements=data_requirements, - train=False, - ) - logging.info("rank %d, data loader initialized" % self.dist.rank) - for gridded_data, name in zip( - (self.train_data, self.valid_data), ("train", "valid") - ): - n_samples = len(gridded_data.loader.dataset) - n_batches = len(gridded_data.loader) - logging.info(f"{name} data: {n_samples} samples, {n_batches} batches") - first_time = gridded_data.loader.dataset[0][1].values[0] - last_time = gridded_data.loader.dataset[-1][1].values[0] - logging.info(f"{name} data: first sample's initial time: {first_time}") - logging.info(f"{name} data: last sample's initial time: {last_time}") - - self.num_batches_seen = 0 - self.startEpoch = 0 - self._model_epoch = self.startEpoch - self.num_batches_seen = 0 - self._best_validation_loss = torch.inf - self._best_inference_error = torch.inf - - for batch in self.train_data.loader: - shapes = {k: v.shape for k, v in batch.data.items()} - for value in shapes.values(): - img_shape = value[-2:] - break + for batch in train_data.loader: + shapes = {k: v.shape for k, v in batch.data.items()} + for value in shapes.values(): + img_shape = value[-2:] break - logging.info("Starting model initialization") - self.stepper = config.stepper.get_stepper( - img_shape=img_shape, - area=self.train_data.area_weights, - sigma_coordinates=self.train_data.sigma_coordinates, - timestep=self.train_data.timestep, - ) - self.optimization = config.optimization.build( - 
self.stepper.module.parameters(), config.max_epochs - ) - self._base_weights = self.config.stepper.get_base_weights() - self._copy_after_batch = config.copy_weights_after_batch - self._no_optimization = NullOptimization() - - if config.resuming: - logging.info("Loading checkpoint %s" % config.latest_checkpoint_path) - self.restore_checkpoint( - config.latest_checkpoint_path, config.ema_checkpoint_path - ) - - wandb = WandB.get_instance() - wandb.watch(self.stepper.modules) - - logging.info( - ( - "Number of trainable model parameters: " - f"{count_parameters(self.stepper.modules)}" - ) - ) - inference_data_requirements = dataclasses.replace(data_requirements) - inference_data_requirements.n_timesteps = config.inference.n_forward_steps + 1 - - self._inference_data = get_inference_data( - config.inference.loader, - config.inference.forward_steps_in_memory, - inference_data_requirements, - ) - - self._ema = self.config.ema.build(self.stepper.modules) - - def switch_off_grad(self, model): - for param in model.parameters(): - param.requires_grad = False - - def train(self): - logging.info("Starting Training Loop...") - - self._model_epoch = self.startEpoch - inference_epochs = list(range(0, self.config.max_epochs))[ - self.config.inference.epochs.slice - ] - if self.config.segment_epochs is None: - segment_max_epochs = self.config.max_epochs - else: - segment_max_epochs = min( - self.startEpoch + self.config.segment_epochs, self.config.max_epochs - ) - # "epoch" describes the loop, self._model_epoch describes model weights - # needed so we can describe the loop even after weights are updated - for epoch in range(self.startEpoch, segment_max_epochs): - # garbage collect to avoid CUDA error in some contexts - # https://github.com/pytorch/pytorch/issues/67978#issuecomment-1661986812 # noqa: E501 - gc.collect() - logging.info(f"Epoch: {epoch+1}") - if isinstance(self.train_data.sampler, torch.utils.data.DistributedSampler): - self.train_data.sampler.set_epoch(epoch) - - start_time = time.time() - logging.info(f"Starting training step on epoch {epoch + 1}") - train_logs = self.train_one_epoch() - train_end = time.time() - logging.info(f"Starting validation step on epoch {epoch + 1}") - valid_logs = self.validate_one_epoch() - valid_end = time.time() - if epoch in inference_epochs: - logging.info(f"Starting inference step on epoch {epoch + 1}") - inference_logs = self.inference_one_epoch() - inference_end: Optional[float] = time.time() - else: - inference_logs = {} - inference_end = None - - train_loss = train_logs["train/mean/loss"] - valid_loss = valid_logs["val/mean/loss"] - inference_error = inference_logs.get( - "inference/time_mean_norm/rmse/channel_mean", None - ) - # need to get the learning rate before stepping the scheduler - lr = self.optimization.learning_rate - self.optimization.step_scheduler(valid_loss) - - if self.dist.is_root(): - if self.config.save_checkpoint: - logging.info(f"Saving checkpoints for epoch {epoch + 1}") - self.save_all_checkpoints(valid_loss, inference_error) - - time_elapsed = time.time() - start_time - logging.info(f"Time taken for epoch {epoch + 1} is {time_elapsed} sec") - logging.info(f"Train loss: {train_loss}. 
Valid loss: {valid_loss}") - if inference_error is not None: - logging.info(f"Inference error: {inference_error}") - - logging.info("Logging to wandb") - all_logs = { - **train_logs, - **valid_logs, - **inference_logs, - **{ - "lr": lr, - "epoch": epoch, - "epoch_train_seconds": train_end - start_time, - "epoch_validation_seconds": valid_end - train_end, - "epoch_total_seconds": time_elapsed, - }, - } - if inference_end is not None: - all_logs["epoch_inference_seconds"] = inference_end - valid_end - wandb = WandB.get_instance() - wandb.log(all_logs, step=self.num_batches_seen) - if segment_max_epochs == self.config.max_epochs: - self.config.clean_wandb() - - def train_one_epoch(self): - """Train for one epoch and return logs from TrainAggregator.""" - wandb = WandB.get_instance() - aggregator = TrainAggregator() - if self.num_batches_seen == 0: - # Before training, log the loss on the first batch. - with torch.no_grad(): - batch = next(iter(self.train_data.loader)) - stepped = self.stepper.run_on_batch( - batch.data, - optimization=self._no_optimization, - n_forward_steps=self.config.n_forward_steps, - ) - - if self.config.log_train_every_n_batches > 0: - with torch.no_grad(): - metrics = { - f"batch_{name}": self.dist.reduce_mean(metric) - for name, metric in sorted(stepped.metrics.items()) - } - wandb.log(metrics, step=self.num_batches_seen) - batch: BatchData - current_time = time.time() - for batch in self.train_data.loader: - stepped = self.stepper.run_on_batch( - batch.data, - self.optimization, - n_forward_steps=self.config.n_forward_steps, - ) - aggregator.record_batch(stepped.metrics["loss"]) - if self._base_weights is not None: - self._copy_after_batch.apply( - weights=self._base_weights, modules=self.stepper.modules - ) - self._ema(model=self.stepper.modules) - self.num_batches_seen += 1 - if ( - self.config.log_train_every_n_batches > 0 - and self.num_batches_seen % self.config.log_train_every_n_batches == 0 - ): - with torch.no_grad(): - metrics = { - f"batch_{name}": self.dist.reduce_mean(metric) - for name, metric in sorted(stepped.metrics.items()) - } - duration = time.time() - current_time - current_time = time.time() - n_samples = ( - self.train_data.loader.batch_size - * self.config.log_train_every_n_batches - ) - samples_per_second = n_samples / duration - metrics["training_samples_per_second"] = samples_per_second - wandb.log(metrics, step=self.num_batches_seen) - self._model_epoch += 1 - - return aggregator.get_logs(label="train") - - @contextlib.contextmanager - def _validation_context(self): - """ - The context for running validation. - - In this context, the stepper uses the EMA model if - `self.config.validate_using_ema` is True. - """ - if self.config.validate_using_ema: - with self._ema_context(): - yield - else: - yield - - @contextlib.contextmanager - def _ema_context(self): - """ - A context where the stepper uses the EMA model. 
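The _ema_context removed here wraps evaluation in an EMA store / copy_to / restore sequence so validation can run on averaged weights and training weights are put back afterwards. A self-contained sketch of the same swap-in, swap-out pattern with a simplified tracker (not fme's EMATracker API):

import contextlib

import torch


class SimpleEMA:
    """Simplified exponential-moving-average tracker (illustrative only)."""

    def __init__(self, module: torch.nn.Module, decay: float = 0.999):
        self.decay = decay
        self.shadow = [p.detach().clone() for p in module.parameters()]

    def update(self, module: torch.nn.Module):
        with torch.no_grad():
            for s, p in zip(self.shadow, module.parameters()):
                s.mul_(self.decay).add_(p, alpha=1.0 - self.decay)

    @contextlib.contextmanager
    def averaged_weights(self, module: torch.nn.Module):
        # store current weights, copy EMA weights in, restore on exit
        backup = [p.detach().clone() for p in module.parameters()]
        with torch.no_grad():
            for s, p in zip(self.shadow, module.parameters()):
                p.copy_(s)
        try:
            yield
        finally:
            with torch.no_grad():
                for b, p in zip(backup, module.parameters()):
                    p.copy_(b)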
- """ - self._ema.store(parameters=self.stepper.modules.parameters()) - self._ema.copy_to(model=self.stepper.modules) - try: - yield - finally: - self._ema.restore(parameters=self.stepper.modules.parameters()) - - def validate_one_epoch(self): - aggregator = OneStepAggregator( - self.train_data.area_weights.to(fme.get_device()), - self.train_data.sigma_coordinates, - self.train_data.metadata, - loss_scaling=self.stepper.effective_loss_scaling, - ) - - with torch.no_grad(), self._validation_context(): - for batch in self.valid_data.loader: - stepped = self.stepper.run_on_batch( - batch.data, - optimization=NullOptimization(), - n_forward_steps=self.config.n_forward_steps, - ) - # Prepend initial condition back to start of windows - # as it's used to compute differenced quantities - ic, normed_ic = self.stepper.get_initial_condition(batch.data) - stepped = stepped.prepend_initial_condition(ic, normed_ic) - - stepped = compute_stepped_derived_quantities( - stepped, - self.valid_data.sigma_coordinates, - self.valid_data.timestep, - forcing_data=stepped.target_data, - ) - aggregator.record_batch( - loss=stepped.metrics["loss"], - target_data=stepped.target_data, - gen_data=stepped.gen_data, - target_data_norm=stepped.target_data_norm, - gen_data_norm=stepped.gen_data_norm, - ) - return aggregator.get_logs(label="val") + break + logging.info("Starting model initialization") + stepper = builder.get_stepper( + img_shape=img_shape, + gridded_operations=train_data.gridded_operations, + vertical_coordinate=train_data.vertical_coordinate, + timestep=train_data.timestep, + ) + end_of_batch_ops = builder.get_end_of_batch_ops(stepper.modules) + + for batch in inference_data.loader: + initial_inference_times = batch.time.isel(time=0) + break + aggregator_builder = AggregatorBuilder( + inference_config=config.inference_aggregator, + gridded_operations=train_data.gridded_operations, + vertical_coordinate=train_data.vertical_coordinate, + horizontal_coordinates=train_data.horizontal_coordinates, + timestep=train_data.timestep, + initial_inference_time=initial_inference_times, + record_step_20=config.inference_n_forward_steps >= 20, + n_timesteps=config.inference_n_forward_steps + stepper.n_ic_timesteps, + variable_metadata=train_data.variable_metadata, + loss_scaling=stepper.effective_loss_scaling, + channel_mean_names=stepper.out_names, + normalize=stepper.normalizer.normalize, + ) + do_gc_collect = fme.get_device() != torch.device("cpu") + trainer_config: TrainConfigProtocol = config # documenting trainer input type + return Trainer( + train_data=train_data, + validation_data=validation_data, + inference_data=inference_data, + stepper=stepper, + build_optimization=builder.get_optimization, + build_ema=builder.get_ema, + config=trainer_config, + aggregator_builder=aggregator_builder, + end_of_batch_callback=end_of_batch_ops, + do_gc_collect=do_gc_collect, + ) - def inference_one_epoch(self): - record_step_20 = self.config.inference.n_forward_steps >= 20 - aggregator_config: InferenceEvaluatorAggregatorConfig = ( - self.config.inference.aggregator - ) - for batch in self._inference_data.loader: - initial_times = batch.times.isel(time=0) - break - aggregator = aggregator_config.build( - area_weights=self.train_data.area_weights.to(fme.get_device()), - sigma_coordinates=self.train_data.sigma_coordinates, - timestep=self.train_data.timestep, - initial_times=initial_times, - record_step_20=record_step_20, - n_timesteps=self.config.inference.n_forward_steps + 1, - metadata=self.train_data.metadata, - 
data_grid=self.train_data.grid, - ) - with torch.no_grad(), self._validation_context(): - run_inference_evaluator( - aggregator=aggregator, - stepper=self.stepper, - data=self._inference_data, - ) - logs = aggregator.get_logs(label="inference") - if "inference/mean/series" in logs: - # Tables don't work well when reported every epoch, this is a quick - # workaround to remove them. Could refactor to avoid returning - # at all, but it's used when converting the logs to epoch-wise - # wandb logs in standalone inference. - logs.pop("inference/mean/series") - if "inference/mean_norm/series" in logs: - logs.pop("inference/mean_norm/series") - return logs - def save_checkpoint(self, checkpoint_path): - # save to a temporary file in case we get pre-empted during save - temporary_location = os.path.join( - os.path.dirname(checkpoint_path), f".{uuid.uuid4()}.tmp" +class AggregatorBuilder( + AggregatorBuilderABC[PrognosticState, TrainOutput, PairedData], +): + def __init__( + self, + inference_config: InferenceEvaluatorAggregatorConfig, + gridded_operations: GriddedOperations, + vertical_coordinate: HybridSigmaPressureCoordinate, + horizontal_coordinates: HorizontalCoordinates, + timestep: timedelta, + initial_inference_time: xr.DataArray, + record_step_20: bool, + n_timesteps: int, + normalize: Callable[[TensorMapping], TensorDict], + variable_metadata: Optional[Mapping[str, VariableMetadata]] = None, + loss_scaling: Optional[Dict[str, torch.Tensor]] = None, + channel_mean_names: Optional[Sequence[str]] = None, + ): + self.inference_config = inference_config + self.gridded_operations = gridded_operations + self.vertical_coordinate = vertical_coordinate + self.horizontal_coordinates = horizontal_coordinates + self.timestep = timestep + self.initial_inference_time = initial_inference_time + self.record_step_20 = record_step_20 + self.n_timesteps = n_timesteps + self.variable_metadata = variable_metadata + self.loss_scaling = loss_scaling + self.channel_mean_names = channel_mean_names + self.normalize = normalize + + def get_train_aggregator(self) -> TrainAggregator: + return TrainAggregator() + + def get_validation_aggregator(self) -> OneStepAggregator: + return OneStepAggregator( + gridded_operations=self.gridded_operations, + variable_metadata=self.variable_metadata, + loss_scaling=self.loss_scaling, ) - try: - torch.save( - { - "num_batches_seen": self.num_batches_seen, - "epoch": self._model_epoch, - "best_validation_loss": self._best_validation_loss, - "best_inference_error": self._best_inference_error, - "stepper": self.stepper.get_state(), - "optimization": self.optimization.get_state(), - "ema": self._ema.get_state(), - }, - temporary_location, - ) - os.replace(temporary_location, checkpoint_path) - finally: - if os.path.exists(temporary_location): - os.remove(temporary_location) - def restore_checkpoint(self, checkpoint_path, ema_checkpoint_path): - _restore_checkpoint(self, checkpoint_path, ema_checkpoint_path) - - def save_all_checkpoints(self, valid_loss: float, inference_error: Optional[float]): - logging.info( - f"Saving latest checkpoint to {self.config.latest_checkpoint_path}" + def get_inference_aggregator( + self, + ) -> InferenceEvaluatorAggregator: + return self.inference_config.build( + vertical_coordinate=self.vertical_coordinate, + horizontal_coordinates=self.horizontal_coordinates, + timestep=self.timestep, + initial_time=self.initial_inference_time, + record_step_20=self.record_step_20, + n_timesteps=self.n_timesteps, + variable_metadata=self.variable_metadata, + 
channel_mean_names=self.channel_mean_names, + normalize=self.normalize, ) - self.save_checkpoint(self.config.latest_checkpoint_path) - if self.config.epoch_checkpoint_enabled(self._model_epoch): - epoch_checkpoint_path = self.config.epoch_checkpoint_path(self._model_epoch) - logging.info(f"Saving epoch checkpoint to {epoch_checkpoint_path}") - self.save_checkpoint(epoch_checkpoint_path) - if self.config.ema_epoch_checkpoint_enabled(self._model_epoch): - ema_epoch_checkpoint_path = self.config.ema_epoch_checkpoint_path( - self._model_epoch - ) - logging.info(f"Saving EMA epoch checkpoint to {ema_epoch_checkpoint_path}") - with self._ema_context(): - self.save_checkpoint(ema_epoch_checkpoint_path) - if self.config.validate_using_ema: - best_checkpoint_context = self._ema_context - else: - best_checkpoint_context = contextlib.nullcontext # type: ignore - with best_checkpoint_context(): - if valid_loss <= self._best_validation_loss: - logging.info( - "Saving lowest validation loss checkpoint to " - f"{self.config.best_checkpoint_path}" - ) - self._best_validation_loss = valid_loss - self.save_checkpoint(self.config.best_checkpoint_path) - if inference_error is not None and ( - inference_error <= self._best_inference_error - ): - logging.info( - f"Epoch inference error ({inference_error}) is lower than " - f"previous best inference error ({self._best_inference_error})." - ) - logging.info( - "Saving lowest inference error checkpoint to " - f"{self.config.best_inference_checkpoint_path}" - ) - self._best_inference_error = inference_error - self.save_checkpoint(self.config.best_inference_checkpoint_path) - with self._ema_context(): - logging.info( - f"Saving latest EMA checkpoint to {self.config.ema_checkpoint_path}" - ) - self.save_checkpoint(self.config.ema_checkpoint_path) -def _restore_checkpoint(trainer: Trainer, checkpoint_path, ema_checkpoint_path): - # separated into a function only to make it easier to mock - checkpoint = torch.load(checkpoint_path, map_location=fme.get_device()) - # restore checkpoint is used for finetuning as well as resuming. - # If finetuning (i.e., not resuming), restore checkpoint - # does not load optimizer state, instead uses config specified lr. 
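The deleted save_checkpoint guards against pre-emption by writing to a temporary file and atomically renaming it into place. Distilled into a standalone helper (hypothetical name), the pattern is:

import os
import uuid

import torch


def atomic_torch_save(state: dict, checkpoint_path: str) -> None:
    # Write to a temporary file in the same directory, then atomically rename
    # it so a pre-empted job never leaves a half-written checkpoint behind.
    tmp_path = os.path.join(
        os.path.dirname(checkpoint_path), f".{uuid.uuid4()}.tmp"
    )
    try:
        torch.save(state, tmp_path)
        os.replace(tmp_path, checkpoint_path)
    finally:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)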
- trainer.stepper.load_state(checkpoint["stepper"]) - trainer.optimization.load_state(checkpoint["optimization"]) - trainer.num_batches_seen = checkpoint["num_batches_seen"] - trainer.startEpoch = checkpoint["epoch"] - trainer._best_validation_loss = checkpoint["best_validation_loss"] - trainer._best_inference_error = checkpoint["best_inference_error"] - ema_checkpoint = torch.load(ema_checkpoint_path, map_location=fme.get_device()) - ema_stepper: SingleModuleStepper = SingleModuleStepper.from_state( - ema_checkpoint["stepper"], - area=trainer.train_data.area_weights, - sigma_coordinates=trainer.train_data.sigma_coordinates, - ) - trainer._ema = EMATracker.from_state(checkpoint["ema"], ema_stepper.modules) +def run_train_from_config(config: TrainConfig): + run_train(TrainBuilders(config), config) -def run_train_from_config(config: TrainConfig): +def run_train(builders: TrainBuilders, config: TrainConfig): dist = Distributed.get_instance() if fme.using_gpu(): torch.backends.cudnn.benchmark = True if not os.path.isdir(config.experiment_dir): os.makedirs(config.experiment_dir, exist_ok=True) - config.configure_logging(log_filename="out.log") + config.logging.configure_logging(config.experiment_dir, log_filename="out.log") env_vars = logging_utils.retrieve_env_vars() logging_utils.log_versions() beaker_url = logging_utils.log_beaker_url() - config.configure_wandb(env_vars=env_vars, resume=True, notes=beaker_url) - trainer = Trainer(config) + config_as_dict = to_flat_dict(dataclasses.asdict(config)) + config.logging.configure_wandb( + config=config_as_dict, env_vars=env_vars, resume=True, notes=beaker_url + ) + trainer = build_trainer(builders, config) trainer.train() logging.info("DONE ---- rank %d" % dist.rank) diff --git a/fme/fme/ace/train/train_config.py b/fme/fme/ace/train/train_config.py index 1781b84..1d3e990 100644 --- a/fme/fme/ace/train/train_config.py +++ b/fme/fme/ace/train/train_config.py @@ -1,24 +1,37 @@ import dataclasses -import logging +import datetime import os -from typing import Any, Dict, Optional, Union - -from fme.core.aggregator import InferenceEvaluatorAggregatorConfig -from fme.core.data_loading.config import DataLoaderConfig, Slice -from fme.core.data_loading.inference import InferenceDataLoaderConfig -from fme.core.dicts import to_flat_dict +from typing import List, Optional, Tuple, Union + +import torch + +from fme.ace.aggregator import InferenceEvaluatorAggregatorConfig +from fme.ace.data_loading.batch_data import GriddedData, InferenceGriddedData +from fme.ace.data_loading.config import DataLoaderConfig +from fme.ace.data_loading.getters import get_data_loader, get_inference_data +from fme.ace.data_loading.inference import InferenceDataLoaderConfig +from fme.ace.requirements import PrognosticStateDataRequirements +from fme.ace.stepper import ( + ExistingStepperConfig, + SingleModuleStepper, + SingleModuleStepperConfig, +) +from fme.core.coordinates import HybridSigmaPressureCoordinate +from fme.core.dataset.requirements import DataRequirements from fme.core.distributed import Distributed -from fme.core.ema import EMAConfig +from fme.core.ema import EMAConfig, EMATracker +from fme.core.generics.trainer import EndOfBatchCallback +from fme.core.gridded_ops import GriddedOperations from fme.core.logging_utils import LoggingConfig -from fme.core.optimization import OptimizationConfig -from fme.core.stepper import ExistingStepperConfig, SingleModuleStepperConfig +from fme.core.optimization import Optimization, OptimizationConfig +from fme.core.typing_ import Slice 
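Slice, now imported from fme.core.typing_, drives the epoch selection used below in get_inference_epochs and in the relocated epoch_checkpoint_enabled. A small worked example with a stand-in dataclass, assuming only that Slice exposes a .slice property wrapping the builtin slice, as its usage in this diff implies:

import dataclasses
from typing import List, Optional


@dataclasses.dataclass
class ExampleSlice:  # stand-in for fme.core.typing_.Slice
    start: Optional[int] = None
    stop: Optional[int] = None
    step: Optional[int] = None

    @property
    def slice(self) -> slice:
        return slice(self.start, self.stop, self.step)


def inference_epochs(max_epochs: int, epochs: ExampleSlice) -> List[int]:
    return list(range(0, max_epochs))[epochs.slice]


# e.g. run inline inference on every 5th epoch of a 20-epoch training run:
assert inference_epochs(20, ExampleSlice(start=0, step=5)) == [0, 5, 10, 15]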
from fme.core.weight_ops import CopyWeightsConfig @dataclasses.dataclass class InlineInferenceConfig: """ - Attributes: + Parameters: loader: configuration for the data loader used during inference n_forward_steps: number of forward steps to take forward_steps_in_memory: number of forward steps to take before @@ -34,7 +47,9 @@ class InlineInferenceConfig: forward_steps_in_memory: int = 2 epochs: Slice = Slice(start=0, stop=None, step=1) aggregator: InferenceEvaluatorAggregatorConfig = dataclasses.field( - default_factory=lambda: InferenceEvaluatorAggregatorConfig() + default_factory=lambda: InferenceEvaluatorAggregatorConfig( + log_global_mean_time_series=False, log_global_mean_norm_time_series=False + ) ) def __post_init__(self): @@ -46,6 +61,14 @@ def __post_init__(self): f"{self.loader.start_indices.n_initial_conditions} and " f"{dist.world_size}." ) + if ( + self.aggregator.log_global_mean_time_series + or self.aggregator.log_global_mean_norm_time_series + ): + # Both of log_global_mean_time_series and + # log_global_mean_norm_time_series must be False for inline inference. + self.aggregator.log_global_mean_time_series = False + self.aggregator.log_global_mean_norm_time_series = False @dataclasses.dataclass @@ -53,7 +76,7 @@ class TrainConfig: """ Configuration for training a model. - Attributes: + Arguments: train_loader: Configuration for the training data loader. validation_loader: Configuration for the validation data loader. stepper: Configuration for the stepper. @@ -103,6 +126,14 @@ class TrainConfig: log_train_every_n_batches: int = 100 segment_epochs: Optional[int] = None + @property + def inference_n_forward_steps(self) -> int: + return self.inference.n_forward_steps + + @property + def inference_aggregator(self) -> InferenceEvaluatorAggregatorConfig: + return self.inference.aggregator + @property def checkpoint_dir(self) -> str: """ @@ -110,63 +141,83 @@ def checkpoint_dir(self) -> str: """ return os.path.join(self.experiment_dir, "training_checkpoints") - @property - def latest_checkpoint_path(self) -> str: - return os.path.join(self.checkpoint_dir, "ckpt.tar") + def clean_wandb(self, experiment_dir: str) -> None: + self.logging.clean_wandb(experiment_dir=experiment_dir) - @property - def best_checkpoint_path(self) -> str: - return os.path.join(self.checkpoint_dir, "best_ckpt.tar") + def get_inference_epochs(self) -> List[int]: + return list(range(0, self.max_epochs))[self.inference.epochs.slice] - @property - def best_inference_checkpoint_path(self) -> str: - return os.path.join(self.checkpoint_dir, "best_inference_ckpt.tar") - @property - def ema_checkpoint_path(self) -> str: - return os.path.join(self.checkpoint_dir, "ema_ckpt.tar") +class TrainBuilders: + def __init__(self, config: TrainConfig): + self.config = config - def epoch_checkpoint_path(self, epoch: int) -> str: - return os.path.join(self.checkpoint_dir, f"ckpt_{epoch:04d}.tar") - - def ema_epoch_checkpoint_path(self, epoch: int) -> str: - return os.path.join(self.checkpoint_dir, f"ema_ckpt_{epoch:04d}.tar") - - def epoch_checkpoint_enabled(self, epoch: int) -> bool: - return epoch_checkpoint_enabled( - epoch, self.max_epochs, self.checkpoint_save_epochs + def _get_train_window_data_requirements(self) -> DataRequirements: + return self.config.stepper.get_evaluation_window_data_requirements( + self.config.n_forward_steps ) - def ema_epoch_checkpoint_enabled(self, epoch: int) -> bool: - return epoch_checkpoint_enabled( - epoch, self.max_epochs, self.ema_checkpoint_save_epochs + def 
_get_evaluation_window_data_requirements(self) -> DataRequirements: + return self.config.stepper.get_evaluation_window_data_requirements( + self.config.inference.forward_steps_in_memory ) - @property - def resuming(self) -> bool: - checkpoint_file_exists = os.path.isfile(self.latest_checkpoint_path) - resuming = True if checkpoint_file_exists else False - return resuming - - def configure_logging(self, log_filename: str): - self.logging.configure_logging(self.experiment_dir, log_filename) - - def configure_wandb(self, env_vars: Optional[Dict[str, Any]] = None, **kwargs): - config = to_flat_dict(dataclasses.asdict(self)) - self.logging.configure_wandb(config=config, env_vars=env_vars, **kwargs) + def _get_initial_condition_data_requirements( + self, + ) -> PrognosticStateDataRequirements: + return self.config.stepper.get_prognostic_state_data_requirements() + + def get_train_data(self) -> GriddedData: + data_requirements = self._get_train_window_data_requirements() + return get_data_loader( + self.config.train_loader, + requirements=data_requirements, + train=True, + ) - def log(self): - logging.info("------------------ Configuration ------------------") - logging.info(str(self)) - logging.info("---------------------------------------------------") + def get_validation_data(self) -> GriddedData: + data_requirements = self._get_train_window_data_requirements() + return get_data_loader( + self.config.validation_loader, + requirements=data_requirements, + train=False, + ) - def clean_wandb(self): - self.logging.clean_wandb(experiment_dir=self.experiment_dir) + def get_evaluation_inference_data( + self, + ) -> InferenceGriddedData: + return get_inference_data( + config=self.config.inference.loader, + total_forward_steps=self.config.inference_n_forward_steps, + window_requirements=self._get_evaluation_window_data_requirements(), + initial_condition=self._get_initial_condition_data_requirements(), + ) + def get_optimization(self, modules: torch.nn.ModuleList) -> Optimization: + return self.config.optimization.build(modules, self.config.max_epochs) + + def get_stepper( + self, + img_shape: Tuple[int, int], + gridded_operations: GriddedOperations, + vertical_coordinate: HybridSigmaPressureCoordinate, + timestep: datetime.timedelta, + ) -> SingleModuleStepper: + return self.config.stepper.get_stepper( + img_shape=img_shape, + gridded_operations=gridded_operations, + vertical_coordinate=vertical_coordinate, + timestep=timestep, + ) -def epoch_checkpoint_enabled( - epoch: int, max_epochs: int, save_epochs: Optional[Slice] -) -> bool: - if save_epochs is None: - return False - return epoch in range(max_epochs)[save_epochs.slice] + def get_ema(self, modules) -> EMATracker: + return self.config.ema.build(modules) + + def get_end_of_batch_ops( + self, modules: List[torch.nn.Module] + ) -> EndOfBatchCallback: + base_weights = self.config.stepper.get_base_weights() + if base_weights is not None: + copy_after_batch = self.config.copy_weights_after_batch + return lambda: copy_after_batch.apply(weights=base_weights, modules=modules) + return lambda: None diff --git a/fme/fme/ace/validate_config.py b/fme/fme/ace/validate_config.py index 0fab210..6fba79f 100644 --- a/fme/fme/ace/validate_config.py +++ b/fme/fme/ace/validate_config.py @@ -7,8 +7,8 @@ from fme.ace.inference.evaluator import InferenceEvaluatorConfig from fme.ace.inference.inference import InferenceConfig +from fme.ace.stepper import SingleModuleStepperConfig from fme.ace.train.train_config import TrainConfig -from fme.core.stepper import 
SingleModuleStepperConfig CONFIG_CHOICES = ["train", "inference", "evaluator"] diff --git a/fme/fme/core/__init__.py b/fme/fme/core/__init__.py index 766547d..b7f028a 100644 --- a/fme/fme/core/__init__.py +++ b/fme/fme/core/__init__.py @@ -1,5 +1,6 @@ from .climate_data import ClimateData from .device import get_device, using_gpu +from .gridded_ops import GriddedOperations from .metrics import ( root_mean_squared_error, spherical_area_weights, @@ -8,7 +9,6 @@ ) from .normalizer import StandardNormalizer, get_normalizer from .packer import Packer -from .stepper import SingleModuleStepper, SingleModuleStepperConfig __all__ = [ "spherical_area_weights", @@ -20,7 +20,6 @@ "StandardNormalizer", "get_normalizer", "Packer", - "SingleModuleStepper", - "SingleModuleStepperConfig", "ClimateData", + "GriddedOperations", ] diff --git a/fme/fme/core/aggregator/inference/main.py b/fme/fme/core/aggregator/inference/main.py deleted file mode 100644 index badbd27..0000000 --- a/fme/fme/core/aggregator/inference/main.py +++ /dev/null @@ -1,552 +0,0 @@ -import dataclasses -import datetime -import warnings -from typing import Dict, Iterable, List, Literal, Mapping, Optional, Protocol, Union - -import torch -import xarray as xr - -from fme.core.data_loading.data_typing import SigmaCoordinates, VariableMetadata -from fme.core.typing_ import TensorMapping -from fme.core.wandb import Table, WandB - -from ..one_step.reduced import MeanAggregator as OneStepMeanAggregator -from .annual import GlobalMeanAnnualAggregator -from .enso import EnsoCoefficientEvaluatorAggregator -from .histogram import HistogramAggregator -from .reduced import MeanAggregator, SingleTargetMeanAggregator -from .seasonal import SeasonalAggregator -from .spectrum import PairedSphericalPowerSpectrumAggregator -from .time_mean import TimeMeanAggregator, TimeMeanEvaluatorAggregator -from .video import VideoAggregator -from .zonal_mean import ZonalMeanAggregator - -wandb = WandB.get_instance() -APPROXIMATELY_TWO_YEARS = datetime.timedelta(days=730) -SLIGHTLY_LESS_THAN_FIVE_YEARS = datetime.timedelta(days=1800) - - -class _Aggregator(Protocol): - @torch.no_grad() - def record_batch( - self, - data: TensorMapping, - ): - ... - - @torch.no_grad() - def get_logs(self, label: str): - ... - - @torch.no_grad() - def get_dataset(self) -> xr.Dataset: - ... - - -class _EvaluatorAggregator(Protocol): - @torch.no_grad() - def record_batch( - self, - loss: float, - target_data: TensorMapping, - gen_data: TensorMapping, - target_data_norm: TensorMapping, - gen_data_norm: TensorMapping, - i_time_start: int = 0, - ): - ... - - @torch.no_grad() - def get_logs(self, label: str): - ... - - @torch.no_grad() - def get_dataset(self) -> xr.Dataset: - ... - - -class _TimeDependentAggregator(Protocol): - @torch.no_grad() - def record_batch( - self, - time: xr.DataArray, - data: TensorMapping, - ): - ... - - @torch.no_grad() - def get_logs(self, label: str): - ... - - -class _TimeDependentEvaluatorAggregator(Protocol): - @torch.no_grad() - def record_batch( - self, - time: xr.DataArray, - target_data: TensorMapping, - gen_data: TensorMapping, - ): - ... - - @torch.no_grad() - def get_logs(self, label: str): - ... - - -@dataclasses.dataclass -class InferenceEvaluatorAggregatorConfig: - """ - Configuration for inference evaluator aggregator. - - Attributes: - log_histograms: Whether to log histograms of the targets and predictions. - log_video: Whether to log videos of the state evolution. 
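The _Aggregator and _EvaluatorAggregator classes in the removed module use typing.Protocol, so an aggregator only needs matching method signatures rather than a shared base class. A minimal, generic illustration of that structural-typing idea (all names below are invented for the example):

from typing import Protocol, runtime_checkable

import xarray as xr


@runtime_checkable
class Aggregator(Protocol):
    def record_batch(self, data): ...

    def get_logs(self, label: str): ...

    def get_dataset(self) -> xr.Dataset: ...


class CountingAggregator:  # satisfies the protocol without inheriting from it
    def __init__(self):
        self.n_batches = 0

    def record_batch(self, data):
        self.n_batches += 1

    def get_logs(self, label: str):
        return {f"{label}/n_batches": self.n_batches}

    def get_dataset(self) -> xr.Dataset:
        return xr.Dataset({"n_batches": xr.DataArray(self.n_batches)})


assert isinstance(CountingAggregator(), Aggregator)  # checks method presence only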
- log_extended_video: Whether to log wandb videos of the predictions with - statistical metrics, only done if log_video is True. - log_zonal_mean_images: Whether to log zonal-mean images (hovmollers) with a - time dimension. - log_seasonal_means: Whether to log seasonal mean metrics and images. - monthly_reference_data: Path to monthly reference data to compare against. - time_mean_reference_data: Path to reference time means to compare against. - """ - - log_histograms: bool = False - log_video: bool = False - log_extended_video: bool = False - log_zonal_mean_images: bool = True - log_seasonal_means: bool = False - monthly_reference_data: Optional[str] = None - time_mean_reference_data: Optional[str] = None - - def build( - self, - area_weights: torch.Tensor, - sigma_coordinates: SigmaCoordinates, - timestep: datetime.timedelta, - n_timesteps: int, - initial_times: xr.DataArray, - record_step_20: bool = False, - data_grid: Literal[ - "legendre-gauss", "equiangular", "healpix" - ] = "legendre-gauss", - metadata: Optional[Mapping[str, VariableMetadata]] = None, - ) -> "InferenceEvaluatorAggregator": - if self.monthly_reference_data is None: - monthly_reference_data = None - else: - monthly_reference_data = xr.open_dataset(self.monthly_reference_data) - if self.time_mean_reference_data is None: - time_mean = None - else: - time_mean = xr.open_dataset(self.time_mean_reference_data) - - if n_timesteps > 2**15 and self.log_zonal_mean_images: - # matplotlib raises an error if image size is too large, and we plot - # one pixel per timestep in the zonal mean images. - warnings.warn( - "Disabling zonal mean images logging due to large number of timesteps" - f" (n_timesteps={n_timesteps}). Set log_zonal_mean_images=False or " - "decrease n_timesteps to below 2**15 to avoid this warning." - ) - log_zonal_mean_images = False - else: - log_zonal_mean_images = self.log_zonal_mean_images - - return InferenceEvaluatorAggregator( - area_weights=area_weights, - sigma_coordinates=sigma_coordinates, - timestep=timestep, - n_timesteps=n_timesteps, - initial_times=initial_times, - log_histograms=self.log_histograms, - log_video=self.log_video, - enable_extended_videos=self.log_extended_video, - log_zonal_mean_images=log_zonal_mean_images, - log_seasonal_means=self.log_seasonal_means, - monthly_reference_data=monthly_reference_data, - time_mean_reference_data=time_mean, - record_step_20=record_step_20, - metadata=metadata, - data_grid=data_grid, - ) - - -class InferenceEvaluatorAggregator: - """ - Aggregates statistics for inference comparing a generated and target series. - - To use, call `record_batch` on the results of each batch, then call - `get_logs` to get a dictionary of statistics when you're done. - """ - - def __init__( - self, - area_weights: torch.Tensor, - sigma_coordinates: SigmaCoordinates, - timestep: datetime.timedelta, - n_timesteps: int, - initial_times: xr.DataArray, - record_step_20: bool = False, - log_video: bool = False, - enable_extended_videos: bool = False, - log_zonal_mean_images: bool = False, - log_seasonal_means: bool = False, - metadata: Optional[Mapping[str, VariableMetadata]] = None, - monthly_reference_data: Optional[xr.Dataset] = None, - log_histograms: bool = False, - data_grid: Literal[ - "legendre-gauss", "equiangular", "healpix" - ] = "legendre-gauss", - time_mean_reference_data: Optional[xr.Dataset] = None, - ): - """ - Args: - area_weights: Area weights for each grid cell. - sigma_coordinates: Data sigma coordinates - timestep: Timestep of the model. 
- n_timesteps: Number of timesteps of inference that will be run. - initial_times: Initial times for each sample. - record_step_20: Whether to record the mean of the 20th steps. - log_video: Whether to log videos of the state evolution. - enable_extended_videos: Whether to log videos of statistical - metrics of state evolution - log_zonal_mean_images: Whether to log zonal-mean images (hovmollers) with a - time dimension. - log_seasonal_means: Whether to log seasonal means metrics and images. - metadata: Mapping of variable names their metadata that will - used in generating logged image captions. - monthly_reference_data: Reference monthly data for computing target stats. - log_histograms: Whether to aggregate histograms. - data_grid: The grid type of the data, used for spherical power spectrum. - time_mean_reference_data: Reference time means for computing bias stats. - """ - self._aggregators: Dict[str, _EvaluatorAggregator] = { - "mean": MeanAggregator( - area_weights, - target="denorm", - n_timesteps=n_timesteps, - metadata=metadata, - ), - "mean_norm": MeanAggregator( - area_weights, - target="norm", - n_timesteps=n_timesteps, - metadata=metadata, - ), - "time_mean": TimeMeanEvaluatorAggregator( - area_weights, - metadata=metadata, - reference_means=time_mean_reference_data, - ), - "time_mean_norm": TimeMeanEvaluatorAggregator( - area_weights, - target="norm", - metadata=metadata, - ), - } - if len(area_weights.shape) == 2: - self._aggregators[ - "spherical_power_spectrum" - ] = PairedSphericalPowerSpectrumAggregator( - area_weights.shape[-2], - area_weights.shape[-1], - data_grid, - ) - else: - warnings.warn( - "Area weights are not 2D, spherical power spectrum will not be computed" - ) - if record_step_20: - self._aggregators["mean_step_20"] = OneStepMeanAggregator( - area_weights, target_time=20 - ) - if log_video: - self._aggregators["video"] = VideoAggregator( - n_timesteps=n_timesteps, - enable_extended_videos=enable_extended_videos, - metadata=metadata, - ) - if log_zonal_mean_images: - self._aggregators["zonal_mean"] = ZonalMeanAggregator( - n_timesteps=n_timesteps, - metadata=metadata, - ) - if log_histograms: - self._aggregators["histogram"] = HistogramAggregator() - self._time_dependent_aggregators: Dict[ - str, _TimeDependentEvaluatorAggregator - ] = {} - if log_seasonal_means: - self._time_dependent_aggregators["seasonal"] = SeasonalAggregator( - area_weights=area_weights, - metadata=metadata, - ) - if n_timesteps * timestep > APPROXIMATELY_TWO_YEARS: - self._time_dependent_aggregators["annual"] = GlobalMeanAnnualAggregator( - area_weights=area_weights, - timestep=timestep, - metadata=metadata, - monthly_reference_data=monthly_reference_data, - ) - if n_timesteps * timestep > SLIGHTLY_LESS_THAN_FIVE_YEARS: - self._time_dependent_aggregators[ - "enso_coefficient" - ] = EnsoCoefficientEvaluatorAggregator( - initial_times, - n_timesteps - 1, - timestep, - area_weights, - metadata=metadata, - ) - - @torch.no_grad() - def record_batch( - self, - loss: float, - time: xr.DataArray, - target_data: TensorMapping, - gen_data: TensorMapping, - target_data_norm: TensorMapping, - gen_data_norm: TensorMapping, - i_time_start: int = 0, - ): - if len(target_data) == 0: - raise ValueError("No data in target_data") - if len(gen_data) == 0: - raise ValueError("No data in gen_data") - target_data = {k: v for k, v in target_data.items() if k in gen_data} - target_data_norm = {k: v for k, v in target_data_norm.items() if k in gen_data} - for aggregator in self._aggregators.values(): - 
aggregator.record_batch( - loss=loss, - target_data=target_data, - gen_data=gen_data, - target_data_norm=target_data_norm, - gen_data_norm=gen_data_norm, - i_time_start=i_time_start, - ) - for time_dependent_aggregator in self._time_dependent_aggregators.values(): - time_dependent_aggregator.record_batch( - time=time, - target_data=target_data, - gen_data=gen_data, - ) - - @torch.no_grad() - def get_logs(self, label: str): - """ - Returns logs as can be reported to WandB. - - Args: - label: Label to prepend to all log keys. - """ - logs = {} - for name, aggregator in self._aggregators.items(): - logs.update(aggregator.get_logs(label=name)) - for name, time_dependent_aggregator in self._time_dependent_aggregators.items(): - logs.update(time_dependent_aggregator.get_logs(label=name)) - logs = {f"{label}/{key}": val for key, val in logs.items()} - return logs - - @torch.no_grad() - def get_inference_logs(self, label: str) -> List[Dict[str, Union[float, int]]]: - """ - Returns a list of logs to report to WandB. - - This is done because in inference, we use the wandb step - as the time step, meaning we need to re-organize the logged data - from tables into a list of dictionaries. - """ - return to_inference_logs(self.get_logs(label=label)) - - @torch.no_grad() - def get_datasets( - self, aggregator_whitelist: Optional[Iterable[str]] = None - ) -> Dict[str, xr.Dataset]: - """ - Args: - aggregator_whitelist: aggregator names to include in the output. If - None, return all the datasets associated with all aggregators. - """ - datasets = ( - (name, agg.get_dataset()) for name, agg in self._aggregators.items() - ) - if aggregator_whitelist is not None: - filter_ = set(aggregator_whitelist) - return {name: ds for name, ds in datasets if name in filter_} - - return {name: ds for name, ds in datasets} - - -def to_inference_logs( - log: Mapping[str, Union[Table, float, int]] -) -> List[Dict[str, Union[float, int]]]: - # we have a dictionary which contains WandB tables - # which we will convert to a list of dictionaries, one for each - # row in the tables. Any scalar values will be reported in the last - # dictionary. - n_rows = 0 - for val in log.values(): - if isinstance(val, Table): - n_rows = max(n_rows, len(val.data)) - logs: List[Dict[str, Union[float, int]]] = [] - for i in range(n_rows): - logs.append({}) - for key, val in log.items(): - if isinstance(val, Table): - for i, row in enumerate(val.data): - for j, col in enumerate(val.columns): - key_without_table_name = key[: key.rfind("/")] - logs[i][f"{key_without_table_name}/{col}"] = row[j] - else: - logs[-1][key] = val - return logs - - -def table_to_logs(table: Table) -> List[Dict[str, Union[float, int]]]: - """ - Converts a WandB table into a list of dictionaries. - """ - logs = [] - for row in table.data: - logs.append({table.columns[i]: row[i] for i in range(len(row))}) - return logs - - -@dataclasses.dataclass -class InferenceAggregatorConfig: - """ - Configuration for inference aggregator. - - Attributes: - time_mean_reference_data: Path to reference time means to compare against. 
- """ - - time_mean_reference_data: Optional[str] = None - - def build( - self, - area_weights: torch.Tensor, - sigma_coordinates: SigmaCoordinates, - timestep: datetime.timedelta, - n_timesteps: int, - metadata: Optional[Mapping[str, VariableMetadata]] = None, - ) -> "InferenceAggregator": - if self.time_mean_reference_data is not None: - time_means = xr.open_dataset(self.time_mean_reference_data) - else: - time_means = None - return InferenceAggregator( - area_weights=area_weights, - sigma_coordinates=sigma_coordinates, - timestep=timestep, - n_timesteps=n_timesteps, - metadata=metadata, - time_mean_reference_data=time_means, - ) - - -class InferenceAggregator: - """ - Aggregates statistics on a single timeseries of data. - - To use, call `record_batch` on the results of each batch, then call - `get_logs` to get a dictionary of statistics when you're done. - """ - - def __init__( - self, - area_weights: torch.Tensor, - sigma_coordinates: SigmaCoordinates, - timestep: datetime.timedelta, - n_timesteps: int, - metadata: Optional[Mapping[str, VariableMetadata]] = None, - time_mean_reference_data: Optional[xr.Dataset] = None, - ): - """ - Args: - area_weights: Area weights for each grid cell. - sigma_coordinates: Data sigma coordinates - timestep: Timestep of the model. - metadata: Mapping of variable names their metadata that will - used in generating logged image captions. - time_mean_reference_data: Reference time means for computing bias stats. - """ - self._aggregators: Dict[str, _Aggregator] = { - "mean": SingleTargetMeanAggregator( - area_weights, - n_timesteps=n_timesteps, - ), - "time_mean": TimeMeanAggregator( - area_weights, - metadata=metadata, - reference_means=time_mean_reference_data, - ), - } - self._time_dependent_aggregators: Dict[str, _TimeDependentAggregator] = {} - - @torch.no_grad() - def record_batch( - self, - time: xr.DataArray, - data: TensorMapping, - i_time_start: int, - ): - if len(data) == 0: - raise ValueError("data is empty") - for aggregator in self._aggregators.values(): - aggregator.record_batch( - data=data, - i_time_start=i_time_start, - ) - for time_dependent_aggregator in self._time_dependent_aggregators.values(): - time_dependent_aggregator.record_batch( - time=time, - data=data, - ) - - @torch.no_grad() - def get_logs(self, label: str): - """ - Returns logs as can be reported to WandB. - - Args: - label: Label to prepend to all log keys. - """ - logs = {} - for name, aggregator in self._aggregators.items(): - logs.update(aggregator.get_logs(label=name)) - for name, time_dependent_aggregator in self._time_dependent_aggregators.items(): - logs.update(time_dependent_aggregator.get_logs(label=name)) - logs = {f"{label}/{key}": val for key, val in logs.items()} - return logs - - @torch.no_grad() - def get_inference_logs(self, label: str) -> List[Dict[str, Union[float, int]]]: - """ - Returns a list of logs to report to WandB. - - This is done because in inference, we use the wandb step - as the time step, meaning we need to re-organize the logged data - from tables into a list of dictionaries. - """ - return to_inference_logs(self.get_logs(label=label)) - - @torch.no_grad() - def get_datasets( - self, aggregator_whitelist: Optional[Iterable[str]] = None - ) -> Dict[str, xr.Dataset]: - """ - Args: - aggregator_whitelist: aggregator names to include in the output. If - None, return all the datasets associated with all aggregators. 
- """ - datasets = ( - (name, agg.get_dataset()) for name, agg in self._aggregators.items() - ) - if aggregator_whitelist is not None: - filter_ = set(aggregator_whitelist) - return {name: ds for name, ds in datasets if name in filter_} - - return {name: ds for name, ds in datasets} diff --git a/fme/fme/core/aggregator/inference/test_evaluator.py b/fme/fme/core/aggregator/inference/test_evaluator.py deleted file mode 100644 index c6edb0b..0000000 --- a/fme/fme/core/aggregator/inference/test_evaluator.py +++ /dev/null @@ -1,206 +0,0 @@ -import datetime - -import numpy as np -import pytest -import torch -import xarray as xr - -import fme -from fme.core.aggregator.inference import InferenceEvaluatorAggregator -from fme.core.data_loading.data_typing import SigmaCoordinates -from fme.core.device import get_device - -TIMESTEP = datetime.timedelta(hours=6) - - -def get_zero_time(shape, dims): - return xr.DataArray(np.zeros(shape, dtype="datetime64[ns]"), dims=dims) - - -def test_logs_labels_exist(): - n_sample = 10 - n_time = 22 - nx = 2 - ny = 2 - nz = 3 - loss = 1.0 - area_weights = torch.ones(ny).to(fme.get_device()) - sigma_coordinates = SigmaCoordinates(torch.arange(nz + 1), torch.arange(nz + 1)) - initial_times = get_zero_time(shape=[n_sample, 0], dims=["sample", "time"]) - - agg = InferenceEvaluatorAggregator( - area_weights, - sigma_coordinates, - TIMESTEP, - n_time, - initial_times, - record_step_20=True, - log_video=True, - log_zonal_mean_images=True, - ) - target_data = {"a": torch.randn(n_sample, n_time, nx, ny, device=get_device())} - gen_data = {"a": torch.randn(n_sample, n_time, nx, ny, device=get_device())} - target_data_norm = {"a": torch.randn(n_sample, n_time, nx, ny, device=get_device())} - gen_data_norm = {"a": torch.randn(n_sample, n_time, nx, ny, device=get_device())} - time = get_zero_time(shape=[n_sample, n_time], dims=["sample", "time"]) - agg.record_batch(loss, time, target_data, gen_data, target_data_norm, gen_data_norm) - logs = agg.get_logs(label="test") - assert "test/mean/series" in logs - assert "test/mean_norm/series" in logs - assert "test/mean_step_20/weighted_rmse/a" in logs - assert "test/mean_step_20/weighted_bias/a" in logs - assert "test/mean_step_20/weighted_grad_mag_percent_diff/a" in logs - table = logs["test/mean/series"] - assert table.columns == [ - "forecast_step", - "weighted_bias/a", - "weighted_grad_mag_percent_diff/a", - "weighted_mean_gen/a", - "weighted_mean_target/a", - "weighted_rmse/a", - "weighted_std_gen/a", - ] - assert "test/time_mean/rmse/a" in logs - assert "test/time_mean/bias/a" in logs - assert "test/time_mean/bias_map/a" in logs - assert "test/time_mean/gen_map/a" in logs - assert "test/zonal_mean/error/a" in logs - assert "test/zonal_mean/gen/a" in logs - assert "test/video/a" in logs - - -def test_inference_logs_labels_exist(): - n_sample = 10 - n_time = 22 - nx = 2 - ny = 2 - nz = 3 - loss = 1.0 - area_weights = torch.ones(ny).to(fme.get_device()) - sigma_coordinates = SigmaCoordinates(torch.arange(nz + 1), torch.arange(nz + 1)) - initial_times = (get_zero_time(shape=[n_sample, 0], dims=["sample", "time"]),) - agg = InferenceEvaluatorAggregator( - area_weights, - sigma_coordinates, - TIMESTEP, - n_time, - initial_times, - record_step_20=True, - log_video=True, - ) - target_data = {"a": torch.randn(n_sample, n_time, nx, ny, device=get_device())} - gen_data = {"a": torch.randn(n_sample, n_time, nx, ny, device=get_device())} - target_data_norm = {"a": torch.randn(n_sample, n_time, nx, ny, device=get_device())} - gen_data_norm = 
{"a": torch.randn(n_sample, n_time, nx, ny, device=get_device())} - time = get_zero_time(shape=[n_sample, n_time], dims=["sample", "time"]) - agg.record_batch(loss, time, target_data, gen_data, target_data_norm, gen_data_norm) - logs = agg.get_inference_logs(label="test") - assert isinstance(logs, list) - assert len(logs) == n_time - assert "test/mean/weighted_bias/a" in logs[0] - assert "test/mean/weighted_mean_gen/a" in logs[0] - assert "test/mean/weighted_mean_target/a" in logs[0] - assert "test/mean/weighted_grad_mag_percent_diff/a" in logs[0] - assert "test/mean/weighted_rmse/a" in logs[0] - assert "test/mean_norm/weighted_bias/a" in logs[0] - assert "test/mean_norm/weighted_mean_gen/a" in logs[0] - assert "test/mean_norm/weighted_mean_target/a" in logs[0] - assert "test/mean_norm/weighted_rmse/a" in logs[0] - # series/table data should be rolled out, not included as a table - assert "test/mean/series" not in logs[0] - assert "test/mean_norm/series" not in logs[0] - assert "test/reduced/series" not in logs[0] - assert "test/reduced_norm/series" not in logs[0] - - -@pytest.mark.parametrize( - "window_len, n_windows", - [ - pytest.param(3, 1, id="single_window"), - pytest.param(3, 2, id="two_windows"), - ], -) -def test_i_time_start_gets_correct_time_longer_windows(window_len: int, n_windows: int): - # while this directly tests the "mean" result, this is really a test that - # the data from the correct timestep is piped into the aggregator. - overlap = 1 # tested code assumes windows have one overlapping point - area_weights = torch.ones(4).to(fme.get_device()) - nz = 3 - sigma_coordinates = SigmaCoordinates(torch.arange(nz + 1), torch.arange(nz + 1)) - initial_times = (get_zero_time(shape=[2, 0], dims=["sample", "time"]),) - agg = InferenceEvaluatorAggregator( - area_weights, - sigma_coordinates, - TIMESTEP, - (window_len - overlap) * n_windows + 1, - initial_times, - ) - target_data = {"a": torch.zeros([2, window_len, 4, 4], device=get_device())} - time = get_zero_time(shape=[2, window_len], dims=["sample", "time"]) - i_start = 0 - for i in range(n_windows): - sample_data = {"a": torch.zeros([2, window_len, 4, 4], device=get_device())} - for i in range(window_len): - sample_data["a"][..., i, :, :] = float(i_start + i) - agg.record_batch( - 1.0, - time=time, - target_data=target_data, - gen_data=sample_data, - target_data_norm=target_data, - gen_data_norm=sample_data, - i_time_start=i_start, - ) - i_start += window_len - overlap # subtract 1 for overlapping windows - logs = agg.get_logs(label="metrics") - table = logs["metrics/mean/series"] - # get the weighted_bias column - bias = table.get_column("weighted_bias/a") - assert len(bias) == (window_len - overlap) * n_windows + overlap - for i in range(len(bias)): - np.testing.assert_allclose(bias[i], float(i), rtol=1e-5) - - -@pytest.mark.parametrize( - "window_len, n_windows, overlap", - [ - pytest.param(3, 1, 0, id="single_window"), - pytest.param(3, 2, 0, id="two_windows"), - pytest.param(3, 2, 1, id="two_windows_overlap"), - ], -) -def test_inference_logs_length(window_len: int, n_windows: int, overlap: int): - """ - Test that the inference logs are the correct length when using one or more - possibly-overlapping windows. 
- """ - area_weights = torch.ones(4).to(fme.get_device()) - nz = 3 - sigma_coordinates = SigmaCoordinates(torch.arange(nz + 1), torch.arange(nz + 1)) - initial_times = (get_zero_time(shape=[2, 0], dims=["sample", "time"]),) - agg = InferenceEvaluatorAggregator( - area_weights, - sigma_coordinates, - TIMESTEP, - (window_len - overlap) * n_windows + overlap, - initial_times, - ) - target_data = {"a": torch.zeros([2, window_len, 4, 4], device=get_device())} - time = get_zero_time(shape=[2, window_len], dims=["sample", "time"]) - i_start = 0 - for i in range(n_windows): - sample_data = {"a": torch.zeros([2, window_len, 4, 4], device=get_device())} - for i in range(window_len): - sample_data["a"][..., i, :, :] = float(i_start + i) - agg.record_batch( - 1.0, - time=time, - target_data=target_data, - gen_data=sample_data, - target_data_norm=target_data, - gen_data_norm=sample_data, - i_time_start=i_start, - ) - i_start += window_len - overlap # subtract 1 for overlapping windows - logs = agg.get_inference_logs(label="metrics") - assert len(logs) == (window_len - overlap) * n_windows + overlap diff --git a/fme/fme/core/aggregator/inference/test_inference.py b/fme/fme/core/aggregator/inference/test_inference.py deleted file mode 100644 index d6781c0..0000000 --- a/fme/fme/core/aggregator/inference/test_inference.py +++ /dev/null @@ -1,101 +0,0 @@ -import datetime - -import numpy as np -import torch -import xarray as xr - -import fme -from fme.core.aggregator.inference import InferenceAggregator -from fme.core.data_loading.data_typing import SigmaCoordinates -from fme.core.device import get_device - -TIMESTEP = datetime.timedelta(hours=6) - - -def get_zero_time(shape, dims): - return xr.DataArray(np.zeros(shape, dtype="datetime64[ns]"), dims=dims) - - -def test_logs_labels_exist(): - n_sample = 10 - n_time = 22 - nx = 2 - ny = 2 - nz = 3 - area_weights = torch.ones(ny).to(fme.get_device()) - sigma_coordinates = SigmaCoordinates(torch.arange(nz + 1), torch.arange(nz + 1)) - agg = InferenceAggregator( - area_weights, - sigma_coordinates, - TIMESTEP, - n_timesteps=n_time, - ) - gen_data = {"a": torch.randn(n_sample, n_time, nx, ny, device=get_device())} - time = get_zero_time(shape=[n_sample, n_time], dims=["sample", "time"]) - agg.record_batch(time, data=gen_data, i_time_start=0) - logs = agg.get_logs(label="test") - assert "test/mean/series" in logs - assert "test/time_mean/gen_map/a" in logs - assert "test/time_mean/ref_bias_map/a" not in logs - assert "test/time_mean/ref_bias/a" not in logs - assert "test/time_mean/ref_rmse/a" not in logs - - -def test_logs_labels_exist_with_reference_time_means(): - n_sample = 10 - n_time = 22 - nx = 2 - ny = 2 - nz = 3 - area_weights = torch.ones(ny).to(fme.get_device()) - sigma_coordinates = SigmaCoordinates(torch.arange(nz + 1), torch.arange(nz + 1)) - reference_time_means = xr.Dataset( - { - "a": xr.DataArray( - np.random.randn(ny, nx), - dims=["grid_yt", "grid_xt"], - ) - } - ) - agg = InferenceAggregator( - area_weights, - sigma_coordinates, - TIMESTEP, - n_timesteps=n_time, - time_mean_reference_data=reference_time_means, - ) - gen_data = {"a": torch.randn(n_sample, n_time, nx, ny, device=get_device())} - time = get_zero_time(shape=[n_sample, n_time], dims=["sample", "time"]) - agg.record_batch(time, data=gen_data, i_time_start=0) - logs = agg.get_logs(label="test") - assert "test/mean/series" in logs - assert "test/time_mean/gen_map/a" in logs - assert "test/time_mean/ref_bias_map/a" in logs - assert "test/time_mean/ref_bias/a" in logs - assert 
"test/time_mean/ref_rmse/a" in logs - - -def test_inference_logs_labels_exist(): - n_sample = 10 - n_time = 22 - nx = 2 - ny = 2 - nz = 3 - area_weights = torch.ones(ny).to(fme.get_device()) - sigma_coordinates = SigmaCoordinates(torch.arange(nz + 1), torch.arange(nz + 1)) - agg = InferenceAggregator( - area_weights, - sigma_coordinates, - TIMESTEP, - n_timesteps=n_time, - ) - gen_data = {"a": torch.randn(n_sample, n_time, nx, ny, device=get_device())} - time = get_zero_time(shape=[n_sample, n_time], dims=["sample", "time"]) - agg.record_batch(time, data=gen_data, i_time_start=0) - logs = agg.get_inference_logs(label="test") - assert isinstance(logs, list) - assert len(logs) == n_time - assert "test/mean/weighted_mean_gen/a" in logs[0] - assert "test/mean/weighted_mean_gen/a" in logs[-1] - # assert len(logs) == n_time use this assertion when timeseries data is generated - assert "test/time_mean/gen_map/a" in logs[-1] diff --git a/fme/fme/core/aggregator/one_step/derived.py b/fme/fme/core/aggregator/one_step/derived.py deleted file mode 100644 index 92de5a7..0000000 --- a/fme/fme/core/aggregator/one_step/derived.py +++ /dev/null @@ -1,135 +0,0 @@ -"""Derived metrics take the global state as input and usually output a new -variable, e.g. dry air mass.""" - -import abc -from dataclasses import dataclass -from typing import Dict, List, Mapping, Optional, Tuple - -import torch - -from fme.core.climate_data import ( - CLIMATE_FIELD_NAME_PREFIXES, - ClimateData, - compute_dry_air_absolute_differences, -) -from fme.core.data_loading.data_typing import SigmaCoordinates -from fme.core.device import get_device -from fme.core.typing_ import TensorMapping - - -@dataclass -class _TargetGenPair: - target: torch.Tensor - gen: torch.Tensor - - -class DerivedMetric(abc.ABC): - """Derived metrics are computed from the global state and usually output a - new variable, e.g. dry air tendencies.""" - - @abc.abstractmethod - def record(self, target: ClimateData, gen: ClimateData) -> None: - ... - - @abc.abstractmethod - def get(self) -> _TargetGenPair: - """Returns the derived metric applied to the target and data generated - by the model.""" - ... - - -class DryAir(DerivedMetric): - """Computes absolute value of the dry air tendency of the first time step, - averaged over the batch. 
If the data does not contain the required fields, - then returns NaN.""" - - def __init__( - self, - area_weights: torch.Tensor, - sigma_coordinates: SigmaCoordinates, - device: torch.device, - spatial_dims=(2, 3), - ): - self._area_weights = area_weights - self._sigma_coordinates = sigma_coordinates - self._dry_air_target_total: Optional[torch.Tensor] = None - self._dry_air_gen_total: Optional[torch.Tensor] = None - self._device = device - self._spatial_dims: Tuple[int, int] = spatial_dims - - def record(self, target: ClimateData, gen: ClimateData) -> None: - def _compute_dry_air_helper(climate_data: ClimateData) -> torch.Tensor: - return compute_dry_air_absolute_differences( - climate_data, - area=self._area_weights, - sigma_coordinates=self._sigma_coordinates, - )[0] - - dry_air_target = _compute_dry_air_helper(target) - dry_air_gen = _compute_dry_air_helper(gen) - - # initialize - if self._dry_air_target_total is None: - self._dry_air_target_total = torch.zeros_like( - dry_air_target, device=self._device - ) - if self._dry_air_gen_total is None: - self._dry_air_gen_total = torch.zeros_like(dry_air_gen, device=self._device) - - self._dry_air_target_total += dry_air_target - self._dry_air_gen_total += dry_air_gen - - def get(self) -> _TargetGenPair: - if self._dry_air_target_total is None or self._dry_air_gen_total is None: - raise ValueError("No batches have been recorded.") - return _TargetGenPair( - target=self._dry_air_target_total, gen=self._dry_air_gen_total - ) - - -class DerivedMetricsAggregator: - def __init__( - self, - area_weights: torch.Tensor, - sigma_coordinates: SigmaCoordinates, - climate_field_name_prefixes: Mapping[ - str, List[str] - ] = CLIMATE_FIELD_NAME_PREFIXES, - ): - self.area_weights = area_weights - self.sigma_coordinates = sigma_coordinates - self.climate_field_name_prefixes = climate_field_name_prefixes - device = get_device() - self._derived_metrics: Dict[str, DerivedMetric] = { - "surface_pressure_due_to_dry_air": DryAir( - self.area_weights, self.sigma_coordinates, device=device - ) - } - self._n_batches = 0 - - @torch.no_grad() - def record_batch( - self, - loss: float, - target_data: TensorMapping, - gen_data: TensorMapping, - target_data_norm: TensorMapping, - gen_data_norm: TensorMapping, - ): - del loss, target_data_norm, gen_data_norm # unused - target = ClimateData(target_data, self.climate_field_name_prefixes) - gen = ClimateData(gen_data, self.climate_field_name_prefixes) - - for metric_fn in self._derived_metrics.values(): - metric_fn.record(target, gen) - - # only increment n_batches if we actually recorded a batch - self._n_batches += 1 - - def get_logs(self, label: str): - logs = dict() - for metric_name in self._derived_metrics: - values = self._derived_metrics[metric_name].get() - logs[f"{label}/{metric_name}/target"] = values.target / self._n_batches - logs[f"{label}/{metric_name}/gen"] = values.gen / self._n_batches - return logs diff --git a/fme/fme/core/aggregator/one_step/test_main.py b/fme/fme/core/aggregator/one_step/test_main.py deleted file mode 100644 index 628b2fc..0000000 --- a/fme/fme/core/aggregator/one_step/test_main.py +++ /dev/null @@ -1,183 +0,0 @@ -import pytest -import torch - -from fme.core.aggregator.one_step import OneStepAggregator -from fme.core.data_loading.data_typing import SigmaCoordinates -from fme.core.device import get_device - - -def test_labels_exist(): - n_sample = 10 - n_time = 3 - nx, ny, nz = 2, 2, 3 - loss = 1.0 - area_weights = torch.ones(ny).to(get_device()) - sigma_coordinates = 
SigmaCoordinates(torch.arange(nz + 1), torch.arange(nz + 1)) - agg = OneStepAggregator(area_weights, sigma_coordinates) - target_data = {"a": torch.randn(n_sample, n_time, nx, ny, device=get_device())} - gen_data = {"a": torch.randn(n_sample, n_time, nx, ny, device=get_device())} - target_data_norm = {"a": torch.randn(n_sample, n_time, nx, ny, device=get_device())} - gen_data_norm = {"a": torch.randn(n_sample, n_time, nx, ny, device=get_device())} - agg.record_batch(loss, target_data, gen_data, target_data_norm, gen_data_norm) - logs = agg.get_logs(label="test") - assert "test/mean/loss" in logs - assert "test/mean/weighted_rmse/a" in logs - assert "test/mean/weighted_bias/a" in logs - assert "test/mean/weighted_grad_mag_percent_diff/a" in logs - assert "test/snapshot/image-full-field/a" in logs - assert "test/snapshot/image-residual/a" in logs - assert "test/snapshot/image-error/a" in logs - - -def test_loss(): - """ - Basic test the aggregator combines loss correctly - with multiple batches and no distributed training. - """ - torch.manual_seed(0) - example_data = { - "a": torch.randn(1, 2, 5, 5, device=get_device()), - } - area_weights = torch.ones(1).to(get_device()) - nz = 3 - sigma_coordinates = SigmaCoordinates(torch.arange(nz + 1), torch.arange(nz + 1)) - aggregator = OneStepAggregator(area_weights, sigma_coordinates) - aggregator.record_batch( - loss=1.0, - target_data=example_data, - gen_data=example_data, - target_data_norm=example_data, - gen_data_norm=example_data, - ) - aggregator.record_batch( - loss=2.0, - target_data=example_data, - gen_data=example_data, - target_data_norm=example_data, - gen_data_norm=example_data, - ) - logs = aggregator.get_logs(label="metrics") - assert logs["metrics/mean/loss"] == 1.5 - aggregator.record_batch( - loss=3.0, - target_data=example_data, - gen_data=example_data, - target_data_norm=example_data, - gen_data_norm=example_data, - ) - logs = aggregator.get_logs(label="metrics") - assert logs["metrics/mean/loss"] == 2.0 - - -def test_aggregator_raises_on_no_data(): - """ - Basic test the aggregator combines loss correctly - with multiple batches and no distributed training. 
- """ - ny, nz = 2, 3 - area_weights = torch.ones(ny).to(get_device()) - sigma_coordinates = SigmaCoordinates(torch.arange(nz + 1), torch.arange(nz + 1)) - agg = OneStepAggregator(area_weights, sigma_coordinates) - with pytest.raises(ValueError) as excinfo: - agg.record_batch( - loss=1.0, target_data={}, gen_data={}, target_data_norm={}, gen_data_norm={} - ) - # check that the raised exception contains the right substring - assert "No data" in str(excinfo.value) - - -def test_derived(): - n_sample = 5 - n_time = 3 - nx, ny, nz = 2, 4, 3 - loss = 1.0 - area_weights = torch.ones(ny).to(get_device()) - sigma_coordinates = SigmaCoordinates( - torch.arange(nz + 1).to(get_device()), torch.arange(nz + 1).to(get_device()) - ) - agg = OneStepAggregator(area_weights, sigma_coordinates) - - def _make_data(): - fields = ["a", "PRESsfc"] + [f"specific_total_water_{i}" for i in range(nz)] - return { - field: torch.randn(n_sample, n_time, nx, ny, device=get_device()) - for field in fields - } - - target_data = _make_data() - gen_data = _make_data() - target_data_norm = _make_data() - gen_data_norm = _make_data() - - agg.record_batch(loss, target_data, gen_data, target_data_norm, gen_data_norm) - - logs = agg.get_logs("") - target = logs["/derived/surface_pressure_due_to_dry_air/target"] - gen = logs["/derived/surface_pressure_due_to_dry_air/gen"] - - assert target.shape == () - assert not torch.isnan(target).any() - assert gen.shape == () - assert not torch.isnan(target).any() - - -def test_derived_missing_surface_pressure(): - n_sample = 5 - n_time = 3 - nx, ny, nz = 2, 4, 3 - loss = 1.0 - area_weights = torch.ones(ny).to(get_device()) - sigma_coordinates = SigmaCoordinates( - torch.arange(nz + 1).to(get_device()), torch.arange(nz + 1).to(get_device()) - ) - agg = OneStepAggregator(area_weights, sigma_coordinates) - - def _make_data(): - fields = ["a"] # N.B. no surface pressure or water fields. 
- return { - field: torch.randn(n_sample, n_time, nx, ny, device=get_device()) - for field in fields - } - - target_data = _make_data() - gen_data = _make_data() - target_data_norm = _make_data() - gen_data_norm = _make_data() - - agg.record_batch(loss, target_data, gen_data, target_data_norm, gen_data_norm) - - logs = agg.get_logs("") - target = logs["/derived/surface_pressure_due_to_dry_air/target"] - gen = logs["/derived/surface_pressure_due_to_dry_air/gen"] - - assert target.shape == () and torch.isnan(target).all() - assert gen.shape == () and torch.isnan(target).all() - - -def test__get_loss_scaled_mse_components(): - x = torch.ones(10).to(get_device()) - loss_scaling = { - "a": torch.tensor(1.0), - "b": torch.tensor(0.5), - } - agg = OneStepAggregator( - area_weights=torch.ones(10).to(get_device()), - sigma_coordinates=SigmaCoordinates(x, x), - loss_scaling=loss_scaling, - ) - - logs = { - "test/mean/weighted_rmse/a": 1.0, - "test/mean/weighted_rmse/b": 4.0, - "test/mean/weighted_rmse/c": 0.0, - } - result = agg._get_loss_scaled_mse_components(logs, "test") - scaled_squared_errors_sum = (1.0 / 1.0) ** 2 + (4.0 / 0.5) ** 2 - assert ( - result["test/mean/mse_fractional_components/a"] == 1 / scaled_squared_errors_sum - ) - assert ( - result["test/mean/mse_fractional_components/b"] - == 64 / scaled_squared_errors_sum - ) - assert "test/mean/mse_fractional_components/c" not in result diff --git a/fme/fme/core/aggregator/test_plotting.py b/fme/fme/core/aggregator/test_plotting.py deleted file mode 100644 index 612b3cd..0000000 --- a/fme/fme/core/aggregator/test_plotting.py +++ /dev/null @@ -1,48 +0,0 @@ -import numpy as np -import pytest - -from .plotting import _stitch_data_panels, get_cmap_limits, plot_imshow - - -def test_cmap_limits(): - data = np.array([1, 2, 3]) - vmin, vmax = get_cmap_limits(data) - assert vmin == 1 - assert vmax == 3 - - -def test_cmap_limits_diverging(): - data = np.array([-1, 2, 3]) - vmin, vmax = get_cmap_limits(data, diverging=True) - assert vmin == -3 - assert vmax == 3 - - -@pytest.mark.parametrize("use_colorbar", [True, False]) -def test_plot_imshow(use_colorbar): - shape = [10, 15] - data = np.random.randn(*shape) - fig = plot_imshow(np.array(data), use_colorbar=use_colorbar) - width, height = (fig.get_size_inches() * fig.dpi).astype(int) - if use_colorbar: - # colorbar is no more than 15% of the width but greater than 0 pixels - assert shape[1] < width <= int(shape[1] * 1.15) - assert height == shape[0] - else: - assert [height, width] == shape - - -def test_stitch_data_panels(): - data = [ - [np.array([[1, 2]]), np.array([[3, 4]])], - [np.array([[5, 6]]), np.array([[7, 8]])], - ] - stitched = _stitch_data_panels(data, vmin=1) - expected = np.array( - [ # vertical orientation is swapped as data starts from bottom-left - [5, 6, 1, 7, 8], - [1, 1, 1, 1, 1], - [1, 2, 1, 3, 4], - ] - ) - assert np.array_equal(stitched, expected) diff --git a/fme/fme/core/climate_data.py b/fme/fme/core/climate_data.py index ebb5077..9891f05 100644 --- a/fme/fme/core/climate_data.py +++ b/fme/fme/core/climate_data.py @@ -1,6 +1,5 @@ -import re from types import MappingProxyType -from typing import List, Mapping, Union +from typing import Callable, List, Mapping import torch @@ -12,7 +11,8 @@ RVGAS, SPECIFIC_HEAT_OF_DRY_AIR_CONST_PRESSURE, ) -from fme.core.data_loading.data_typing import SigmaCoordinates +from fme.core.coordinates import HybridSigmaPressureCoordinate +from fme.core.stacker import Stacker from fme.core.typing_ import TensorDict, TensorMapping 
CLIMATE_FIELD_NAME_PREFIXES = MappingProxyType( @@ -20,47 +20,29 @@ "specific_total_water": ["specific_total_water_"], "surface_pressure": ["PRESsfc", "PS"], "surface_height": ["HGTsfc"], + "surface_geopotential": ["PHIS"], "tendency_of_total_water_path_due_to_advection": [ "tendency_of_total_water_path_due_to_advection" ], "latent_heat_flux": ["LHTFLsfc", "LHFLX"], - "sensible_heat_flux": ["SHTFLsfc"], + "sensible_heat_flux": ["SHTFLsfc", "SHFLX"], "precipitation_rate": ["PRATEsfc", "surface_precipitation_rate"], - "sfc_down_sw_radiative_flux": ["DSWRFsfc"], - "sfc_up_sw_radiative_flux": ["USWRFsfc"], - "sfc_down_lw_radiative_flux": ["DLWRFsfc"], - "sfc_up_lw_radiative_flux": ["ULWRFsfc"], - "air_temperature": ["air_temperature_"], + "sfc_down_sw_radiative_flux": ["DSWRFsfc", "FSDS"], + "sfc_up_sw_radiative_flux": ["USWRFsfc", "surface_upward_shortwave_flux"], + "sfc_down_lw_radiative_flux": ["DLWRFsfc", "FLDS"], + "sfc_up_lw_radiative_flux": ["ULWRFsfc", "surface_upward_longwave_flux"], + "toa_up_lw_radiative_flux": ["ULWRFtoa", "FLUT"], + "toa_up_sw_radiative_flux": ["USWRFtoa", "top_of_atmos_upward_shortwave_flux"], + "toa_down_sw_radiative_flux": ["DSWRFtoa", "SOLIN"], + "air_temperature": ["air_temperature_", "T_"], } ) -def natural_sort(alist: List[str]) -> List[str]: - """Sort to alphabetical order but with numbers sorted - numerically, e.g. a11 comes after a2. See [1] and [2]. - - [1] https://stackoverflow.com/questions/11150239/natural-sorting - [2] https://en.wikipedia.org/wiki/Natural_sort_order - """ - - def convert(text: str) -> Union[str, int]: - if text.isdigit(): - return int(text) - else: - return text.lower() - - def alphanum_key(item: str) -> List[Union[str, int]]: - return [convert(c) for c in re.split("([0-9]+)", item)] - - return sorted(alist, key=alphanum_key) - - -LEVEL_PATTERN = re.compile(r"_(\d+)$") - - class ClimateData: """Container for climate data for accessing variables and providing - torch.Tensor views on data with multiple vertical levels.""" + torch.Tensor views on data with multiple vertical levels. + """ def __init__( self, @@ -74,87 +56,66 @@ def __init__( Args: climate_data: Mapping from field names to tensors. - climate_field_name_prefixes: Mapping from field name prefixes (e.g. - "specific_total_water_") to standardized prefixes, e.g. "PRESsfc" → - "surface_pressure". + climate_field_name_prefixes: Mapping which defines the correspondence + between an arbitrary set of "standard" names (e.g., "surface_pressure" + or "air_temperature") and lists of possible names or prefix variants + (e.g., ["PRESsfc", "PS"] or ["air_temperature_", "T_"]) found in the + data. """ self._data = dict(climate_data) - self._prefixes = climate_field_name_prefixes - - def _extract_levels(self, name: List[str]) -> torch.Tensor: - for prefix in name: - try: - return self._extract_prefix_levels(prefix) - except KeyError: - pass - raise KeyError(name) - - def _extract_prefix_levels(self, prefix: str) -> torch.Tensor: - names = [ - field_name for field_name in self._data if field_name.startswith(prefix) - ] - - levels = [] - for name in names: - match = LEVEL_PATTERN.search(name) - if match is None: - raise ValueError( - f"Invalid field name {name}, is a prefix variable " - "but does not end in _(number)." 
- ) - levels.append(int(match.group(1))) + self._prefix_map = climate_field_name_prefixes + self._stacker = Stacker(climate_field_name_prefixes) - for i, level in enumerate(sorted(levels)): - if i != level: - raise KeyError(f"Missing level {i} in {prefix} levels {levels}.") - - if len(names) == 0: - raise KeyError(prefix) + @property + def data(self) -> TensorDict: + """Mapping from field names to tensors.""" + return self._data - names = natural_sort(names) - return torch.stack([self._data[name] for name in names], dim=-1) - - def _get(self, name): - for prefix in self._prefixes[name]: - if prefix in self._data.keys(): - return self._get_prefix(prefix) - raise KeyError(name) + def __getitem__(self, name: str): + return getattr(self, name) def _get_prefix(self, prefix): - return self._data[prefix] + return self.data[prefix] def _set(self, name, value): - for prefix in self._prefixes[name]: - if prefix in self._data.keys(): + for prefix in self._prefix_map[name]: + if prefix in self.data.keys(): self._set_prefix(prefix, value) return raise KeyError(name) def _set_prefix(self, prefix, value): - self._data[prefix] = value + self.data[prefix] = value - @property - def data(self) -> TensorDict: - """Mapping from field names to tensors.""" - return self._data + def _get(self, name): + for prefix in self._prefix_map[name]: + if prefix in self.data.keys(): + return self._get_prefix(prefix) + raise KeyError(name) @property def air_temperature(self) -> torch.Tensor: """Returns all vertical levels of air_temperature, e.g. a tensor of - shape `(..., vertical_level)`.""" - prefix = self._prefixes["air_temperature"] - return self._extract_levels(prefix) + shape `(..., vertical_level)`. + """ + return self._stacker("air_temperature", self.data) @property def specific_total_water(self) -> torch.Tensor: """Returns all vertical levels of specific total water, e.g. a tensor of - shape `(..., vertical_level)`.""" - prefix = self._prefixes["specific_total_water"] - return self._extract_levels(prefix) + shape `(..., vertical_level)`. 
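# Usage sketch for the prefix-map-based ClimateData accessors introduced above.
# This is illustrative only: shapes and values are made up, and it assumes the
# new Stacker resolves level-suffixed names such as "T_0"/"T_1" for the
# "air_temperature" prefixes listed in CLIMATE_FIELD_NAME_PREFIXES.
import torch

from fme.core.climate_data import ClimateData

shape = (2, 3, 8, 16)  # (sample, time, lat, lon)
data = {
    "PS": 1.0e5 * torch.ones(*shape),   # resolves to "surface_pressure"
    "T_0": 250.0 * torch.ones(*shape),  # two vertical levels of temperature
    "T_1": 280.0 * torch.ones(*shape),
}
climate_data = ClimateData(data)

# "surface_pressure" is found under its alternate name "PS".
assert torch.equal(climate_data.surface_pressure, data["PS"])

# The stacker gathers T_0 and T_1 into a trailing vertical dimension.
assert climate_data.air_temperature.shape == (*shape, 2)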
+ """ + return self._stacker("specific_total_water", self.data) @property def surface_height(self) -> torch.Tensor: - return self._get("surface_height") + try: + return self._get("surface_height") + except KeyError: + # E3SM saves geopotential not surface height so need to convert + # by using g value from e3sm + GRAVITY_E3SM = 9.80616 + return self._get("surface_geopotential") / GRAVITY_E3SM @property def surface_pressure(self) -> torch.Tensor: @@ -164,22 +125,45 @@ def surface_pressure(self) -> torch.Tensor: def surface_pressure(self, value: torch.Tensor): self._set("surface_pressure", value) + @property + def toa_down_sw_radiative_flux(self) -> torch.Tensor: + return self._get("toa_down_sw_radiative_flux") + + @toa_down_sw_radiative_flux.setter + def toa_down_sw_radiative_flux(self, value: torch.Tensor): + self._set("toa_down_sw_radiative_flux", value) + + @property + def toa_up_sw_radiative_flux(self) -> torch.Tensor: + return self._get("toa_up_sw_radiative_flux") + + @toa_up_sw_radiative_flux.setter + def toa_up_sw_radiative_flux(self, value: torch.Tensor): + self._set("toa_up_sw_radiative_flux", value) + + @property + def toa_up_lw_radiative_flux(self) -> torch.Tensor: + return self._get("toa_up_lw_radiative_flux") + + @toa_up_lw_radiative_flux.setter + def toa_up_lw_radiative_flux(self, value: torch.Tensor): + self._set("toa_up_lw_radiative_flux", value) + def surface_pressure_due_to_dry_air( - self, sigma_coordinates: SigmaCoordinates + self, vertical_coordinate: HybridSigmaPressureCoordinate ) -> torch.Tensor: return metrics.surface_pressure_due_to_dry_air( self.specific_total_water, self.surface_pressure, - sigma_coordinates.ak, - sigma_coordinates.bk, + vertical_coordinate, ) - def total_water_path(self, sigma_coordinates: SigmaCoordinates) -> torch.Tensor: - return metrics.vertical_integral( + def total_water_path( + self, vertical_coordinate: HybridSigmaPressureCoordinate + ) -> torch.Tensor: + return vertical_coordinate.vertical_integral( self.specific_total_water, self.surface_pressure, - sigma_coordinates.ak, - sigma_coordinates.bk, ) @property @@ -240,23 +224,25 @@ def tendency_of_total_water_path_due_to_advection(self, value: torch.Tensor): self._set("tendency_of_total_water_path_due_to_advection", value) def height_at_log_midpoint( - self, sigma_coordinates: SigmaCoordinates + self, vertical_coordinate: HybridSigmaPressureCoordinate ) -> torch.Tensor: """ - Compute vertical height at layer log midpoints + Compute vertical height at layer log midpoints. """ - pressure_interfaces = _pressure_at_interface( - sigma_coordinates.ak, sigma_coordinates.bk, self.surface_pressure + interface_pressure = vertical_coordinate.interface_pressure( + self.surface_pressure ) layer_thickness = _layer_thickness( - pressure_at_interface=pressure_interfaces, + pressure_at_interface=interface_pressure, air_temperature=self.air_temperature, specific_total_water=self.specific_total_water, ) height_at_interface = _height_at_interface(layer_thickness, self.surface_height) return (height_at_interface[..., :-1] * height_at_interface[..., 1:]) ** 0.5 - def moist_static_energy(self, sigma_coordinates: SigmaCoordinates) -> torch.Tensor: + def moist_static_energy( + self, vertical_coordinate: HybridSigmaPressureCoordinate + ) -> torch.Tensor: """ Compute moist static energy. 
""" @@ -265,20 +251,22 @@ def moist_static_energy(self, sigma_coordinates: SigmaCoordinates) -> torch.Tens return ( self.air_temperature * SPECIFIC_HEAT_OF_DRY_AIR_CONST_PRESSURE + self.specific_total_water * LATENT_HEAT_OF_VAPORIZATION - + self.height_at_log_midpoint(sigma_coordinates) * GRAVITY + + self.height_at_log_midpoint(vertical_coordinate) * GRAVITY ) def compute_dry_air_absolute_differences( - climate_data: ClimateData, area: torch.Tensor, sigma_coordinates: SigmaCoordinates + climate_data: ClimateData, + area_weighted_mean: Callable[[torch.Tensor], torch.Tensor], + vertical_coordinate: HybridSigmaPressureCoordinate, ) -> torch.Tensor: """ Computes the absolute value of the dry air tendency of each time step. Args: climate_data: ClimateData object. - area: Area of each grid cell as a [lat, lon] tensor, in m^2. - sigma_coordinates: The sigma coordinates of the model. + area_weighted_mean: Function which returns an area-weighted mean. + vertical_coordinate: The vertical coordinate of the model. Returns: A tensor of shape (time,) of the absolute value of the dry air tendency @@ -289,21 +277,11 @@ def compute_dry_air_absolute_differences( pressure = climate_data.surface_pressure except KeyError: return torch.tensor([torch.nan]) - return ( - metrics.weighted_mean( - metrics.surface_pressure_due_to_dry_air( - water, # (sample, time, y, x, level) - pressure, - sigma_coordinates.ak, - sigma_coordinates.bk, - ), - area, - dim=(2, 3), - ) - .diff(dim=-1) - .abs() - .mean(dim=0) + ps_dry = metrics.surface_pressure_due_to_dry_air( + water, pressure, vertical_coordinate ) + ps_dry_mean = area_weighted_mean(ps_dry) + return ps_dry_mean.diff(dim=-1).abs().mean(dim=0) def _layer_thickness( @@ -314,25 +292,14 @@ def _layer_thickness( """ Computes vertical thickness of each layer assuming hydrostatic equilibrium. ACE does not currently prognose specific humidity, so here we closely - approximate this using specific total water.""" + approximate this using specific total water. + """ tv = air_temperature * (1 + (RVGAS / RDGAS - 1.0) * specific_total_water) - dlogp = torch.log(pressure_at_interface).diff(dim=-1) + # Enforce min log(p) = 0 so that geopotential energy calculation is finite + dlogp = torch.clamp(torch.log(pressure_at_interface), min=0.0).diff(dim=-1) return dlogp * RDGAS * tv / GRAVITY -def _pressure_at_interface( - ak: torch.tensor, bk: torch.tensor, surface_pressure: torch.tensor -) -> torch.Tensor: - """ - Computes pressure at layer interfaces from sigma coefficients. - Vertical coordinate is the last tensor dimension. 
- """ - return torch.stack( - [ak[i] + bk[i] * surface_pressure for i in range(ak.shape[-1])], - dim=-1, - ) - - def _height_at_interface( layer_thickness: torch.tensor, surface_height: torch.tensor ) -> torch.Tensor: diff --git a/fme/fme/core/coordinates.py b/fme/fme/core/coordinates.py new file mode 100644 index 0000000..8e059a9 --- /dev/null +++ b/fme/fme/core/coordinates.py @@ -0,0 +1,397 @@ +import abc +import dataclasses +from typing import List, Literal, Mapping, Optional, Tuple, TypeVar + +import numpy as np +import torch +from astropy_healpix import HEALPix + +from fme.core import metrics +from fme.core.constants import GRAVITY +from fme.core.gridded_ops import GriddedOperations, HEALPixOperations, LatLonOperations +from fme.core.typing_ import TensorMapping +from fme.core.winds import lon_lat_to_xyz + +HC = TypeVar("HC", bound="HorizontalCoordinates") + + +@dataclasses.dataclass +class HybridSigmaPressureCoordinate: + """ + Defines pressure at interface levels according to the following formula: + p(k) = a(k) + b(k)*ps. + + where ps is the surface pressure, a and b are the sigma-pressure coordinates. + + Parameters: + ak: a(k) coefficients as a 1-dimensional tensor + bk: b(k) coefficients as a 1-dimensional tensor + """ + + ak: torch.Tensor + bk: torch.Tensor + + def __post_init__(self): + if len(self.ak.shape) != 1: + raise ValueError( + f"ak must be a 1-dimensional tensor. Got shape: {self.ak.shape}" + ) + if len(self.bk.shape) != 1: + raise ValueError( + f"bk must be a 1-dimensional tensor. Got shape: {self.bk.shape}" + ) + if len(self.ak) != len(self.bk): + raise ValueError( + f"ak and bk must have the same length. Got len(ak)={len(self.ak)} and " + f"len(bk)={len(self.bk)}." + ) + + def __len__(self): + """The number of vertical layer interfaces.""" + return len(self.ak) + + @property + def coords(self) -> Mapping[str, np.ndarray]: + return {"ak": self.ak.cpu().numpy(), "bk": self.bk.cpu().numpy()} + + def to(self, device: str) -> "HybridSigmaPressureCoordinate": + return HybridSigmaPressureCoordinate( + ak=self.ak.to(device), + bk=self.bk.to(device), + ) + + def __eq__(self, other) -> bool: + if not isinstance(other, HybridSigmaPressureCoordinate): + return False + return torch.allclose(self.ak, other.ak) and torch.allclose(self.bk, other.bk) + + def as_dict(self) -> TensorMapping: + return {"ak": self.ak, "bk": self.bk} + + def interface_pressure(self, surface_pressure: torch.Tensor) -> torch.Tensor: + """ + Compute pressure at vertical layer interfaces. + + Args: + surface_pressure: The surface pressure in units of Pa. + + Returns: + A tensor of pressure at vertical layer interfaces. Will contain a new + dimension at the end, representing the vertical. + """ + return torch.stack( + [ak + bk * surface_pressure for ak, bk in zip(self.ak, self.bk)], + dim=-1, + ) + + def vertical_integral( + self, integrand: torch.Tensor, surface_pressure: torch.Tensor + ) -> torch.Tensor: + """ + Compute the mass-weighted vertical integral of the integrand. + + (1 / g) * ∫ x dp + + where + - g = acceleration due to gravity + - x = integrad + - p = pressure level + + Args: + surface_pressure: The surface pressure in units of Pa. + integrand: A tensor whose last dimension is the vertical. + + Returns: + A tensor of same shape as integrand but without the last dimension. + """ + if len(self.ak) != integrand.shape[-1] + 1: + raise ValueError( + "The last dimension of integrand must match the number of vertical " + "layers in the hybrid sigma-pressure vertical coordinate." 
+ ) + interface_pressure = self.interface_pressure(surface_pressure) + pressure_thickness = interface_pressure.diff(dim=-1) + return (integrand * pressure_thickness).sum(dim=-1) / GRAVITY + + +@dataclasses.dataclass +class DimSize: + name: str + size: int + + +class HorizontalCoordinates(abc.ABC): + """ + Parent class for horizontal coordinate system grids. + Contains coords which must be subclassed to provide the coordinates. + """ + + @abc.abstractmethod + def __eq__(self, other) -> bool: + pass + + @abc.abstractmethod + def to(self: HC, device: str) -> HC: + pass + + @property + @abc.abstractmethod + def coords(self) -> Mapping[str, np.ndarray]: + pass + + @property + @abc.abstractmethod + def xyz(self) -> Tuple[float, float, float]: + pass + + @property + @abc.abstractmethod + def dims(self) -> List[str]: + """Names of model horizontal dimensions.""" + pass + + @property + @abc.abstractmethod + def loaded_dims(self) -> List[str]: + """Names of horizontal dimensions as loaded from training dataset.""" + pass + + @property + @abc.abstractmethod + def loaded_sizes(self) -> List[DimSize]: + """Sizes of horizontal dimensions as loaded from training dataset.""" + pass + + @property + @abc.abstractmethod + def loaded_default_sizes(self) -> List[DimSize]: + """Default sizes of horizontal data dimensions, used by testing code.""" + pass + + @property + @abc.abstractmethod + def grid(self) -> Literal["equiangular", "legendre-gauss", "healpix"]: + pass + + # A temporary solution for training which allows us to aggregate along the + # latitude dimension. + # TODO: https://github.com/ai2cm/full-model/issues/1003 + @abc.abstractmethod + def get_lat(self) -> torch.Tensor: + pass + + @property + @abc.abstractmethod + def area_weights(self) -> Optional[torch.Tensor]: + pass + + @property + @abc.abstractmethod + def gridded_operations(self) -> GriddedOperations: + pass + + @property + @abc.abstractmethod + def meshgrid(self) -> Tuple[torch.Tensor, torch.Tensor]: + """Meshgrids of latitudes and longitudes, respectively.""" + pass + + +@dataclasses.dataclass +class LatLonCoordinates(HorizontalCoordinates): + """ + Defines a (latitude, longitude) grid. + + Parameters: + lat: 1-dimensional tensor of latitudes + lon: 1-dimensional tensor of longitudes + loaded_lat_name: name of the latitude dimension + as loaded from training dataset + loaded_lon_name: name of the longitude dimension + as loaded from training dataset + """ + + lon: torch.Tensor + lat: torch.Tensor + loaded_lat_name: str = "lat" + loaded_lon_name: str = "lon" + + def __post_init__(self): + self._area_weights: Optional[torch.Tensor] = None + + def __eq__(self, other) -> bool: + if not isinstance(other, LatLonCoordinates): + return False + return ( + torch.allclose(self.lat, other.lat) + and torch.allclose(self.lon, other.lon) + and self.loaded_lat_name == other.loaded_lat_name + and self.loaded_lon_name == other.loaded_lon_name + ) + + def to(self, device: str) -> "LatLonCoordinates": + return LatLonCoordinates( + lon=self.lon.to(device), + lat=self.lat.to(device), + loaded_lat_name=self.loaded_lat_name, + loaded_lon_name=self.loaded_lon_name, + ) + + @property + def area_weights(self) -> torch.Tensor: + if self._area_weights is None: + self._area_weights = metrics.spherical_area_weights(self.lat, len(self.lon)) + return self._area_weights + + @property + def coords(self) -> Mapping[str, np.ndarray]: + # TODO: Replace with lat/lon name? 
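# Minimal numeric sketch of the HybridSigmaPressureCoordinate defined above.
# With a single layer spanning the whole column (ak = [0, 0], bk = [0, 1]) the
# interface pressures are [0, ps], so the mass-weighted integral
# (1 / g) * sum_k(x_k * dp_k) reduces to x * ps / g. Values are illustrative.
import torch

from fme.core.constants import GRAVITY
from fme.core.coordinates import HybridSigmaPressureCoordinate

coordinate = HybridSigmaPressureCoordinate(
    ak=torch.tensor([0.0, 0.0]),
    bk=torch.tensor([0.0, 1.0]),
)
surface_pressure = torch.tensor([1.0e5])  # Pa, a single column
integrand = torch.ones(1, 1)              # one column, one vertical layer

assert coordinate.interface_pressure(surface_pressure).shape == (1, 2)
torch.testing.assert_close(
    coordinate.vertical_integral(integrand, surface_pressure),
    surface_pressure / GRAVITY,
)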
+ return { + "lat": self.lat.cpu().type(torch.float32).numpy(), + "lon": self.lon.cpu().type(torch.float32).numpy(), + } + + @property + def xyz(self) -> Tuple[float, float, float]: + lats, lons = np.broadcast_arrays( + self.coords["lat"][:, None], self.coords["lon"][None, :] + ) + return lon_lat_to_xyz(lons, lats) + + def get_lat(self) -> torch.Tensor: + return self.lat + + @property + def dims(self) -> List[str]: + return ["lat", "lon"] + + @property + def loaded_dims(self) -> List[str]: + return [self.loaded_lat_name, self.loaded_lon_name] + + @property + def loaded_sizes(self) -> List[DimSize]: + return [ + DimSize(self.loaded_lat_name, len(self.lat)), + DimSize(self.loaded_lon_name, len(self.lon)), + ] + + @property + def loaded_default_sizes(self) -> List[DimSize]: + return [DimSize(self.loaded_lat_name, 16), DimSize(self.loaded_lon_name, 32)] + + @property + def grid(self) -> Literal["equiangular", "legendre-gauss"]: + if torch.allclose( + self.lat[1:] - self.lat[:-1], + self.lat[1] - self.lat[0], + ): + return "equiangular" + else: + return "legendre-gauss" + + @property + def gridded_operations(self) -> LatLonOperations: + return LatLonOperations(self.area_weights) + + @property + def meshgrid(self) -> Tuple[torch.Tensor, torch.Tensor]: + return torch.meshgrid(self.lat, self.lon, indexing="ij") + + +@dataclasses.dataclass +class HEALPixCoordinates(HorizontalCoordinates): + """ + Defines a HEALPix (face, height, width) grid. See https://healpix.jpl.nasa.gov/ for + more information. + + Parameters: + face: 1-dimensional tensor of faces + height: 1-dimensional tensor of heights + width: 1-dimensional tensor of widths + """ + + face: torch.Tensor + height: torch.Tensor + width: torch.Tensor + + def __eq__(self, other) -> bool: + if not isinstance(other, HEALPixCoordinates): + return False + return ( + torch.allclose(self.face, other.face) + and torch.allclose(self.height, other.height) + and torch.allclose(self.width, other.width) + ) + + def to(self, device: str) -> "HEALPixCoordinates": + return HEALPixCoordinates( + face=self.face.to(device), + height=self.height.to(device), + width=self.width.to(device), + ) + + @property + def coords(self) -> Mapping[str, np.ndarray]: + return { + "face": self.face.cpu().type(torch.float32).numpy(), + "height": self.height.cpu().type(torch.float32).numpy(), + "width": self.width.cpu().type(torch.float32).numpy(), + } + + @property + def xyz(self) -> Tuple[float, float, float]: + hp = HEALPix(nside=len(self.height), order="ring") + return hp.healpix_to_xyz( + [self.coords["face"], self.coords["height"], self.coords["width"]] + ) + + @property + def dims(self) -> List[str]: + return ["face", "height", "width"] + + @property + def loaded_dims(self) -> List[str]: + return self.dims + + @property + def loaded_sizes(self) -> List[DimSize]: + return [ + DimSize("face", len(self.face)), + DimSize("height", len(self.width)), + DimSize("width", len(self.height)), + ] + + @property + def loaded_default_sizes(cls) -> List[DimSize]: + return [ + DimSize("face", 12), + DimSize("height", 16), + DimSize("width", 16), + ] + + # TODO: https://github.com/ai2cm/full-model/issues/1003 + # This is currently the dummy solution. + def get_lat(self) -> torch.Tensor: + raise NotImplementedError( + "healpix does not support get_lat. If latitude is needed \ + for some reason, you may use this class's self.xyz property to derive it." 
+ ) + + @property + def grid(self) -> Literal["healpix"]: + return "healpix" + + @property + def area_weights(self) -> Literal[None]: + return None + + @property + def gridded_operations(self) -> HEALPixOperations: + return HEALPixOperations() + + @property + def meshgrid(self) -> Tuple[torch.Tensor, torch.Tensor]: + raise NotImplementedError( + "meshgrid is not implemented yet for HEALPixCoordinates." + ) diff --git a/fme/fme/core/corrector/__init__.py b/fme/fme/core/corrector/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/fme/fme/core/corrector.py b/fme/fme/core/corrector/corrector.py similarity index 67% rename from fme/fme/core/corrector.py rename to fme/fme/core/corrector/corrector.py index 1c01ab1..bf3f72f 100644 --- a/fme/fme/core/corrector.py +++ b/fme/fme/core/corrector/corrector.py @@ -1,17 +1,22 @@ import dataclasses import datetime -from typing import List, Literal, Optional +from typing import Any, Callable, List, Literal, Mapping, Optional, Protocol +import dacite import torch -from fme.core import metrics +import fme from fme.core.climate_data import ClimateData -from fme.core.data_loading.data_typing import SigmaCoordinates +from fme.core.coordinates import HybridSigmaPressureCoordinate +from fme.core.corrector.registry import CorrectorABC, CorrectorConfigProtocol +from fme.core.gridded_ops import GriddedOperations +from fme.core.registry.corrector import CorrectorSelector from fme.core.typing_ import TensorDict, TensorMapping +@CorrectorSelector.register("atmosphere_corrector") @dataclasses.dataclass -class CorrectorConfig: +class CorrectorConfig(CorrectorConfigProtocol): r""" Configuration for the post-step state corrector. @@ -57,7 +62,7 @@ class CorrectorConfig: advection is zero. Therefore ``zero_global_mean_moisture_advection`` must be True if using a ``moisture_budget_correction`` option other than ``None``. - Attributes: + Parameters: conserve_dry_air: If True, force the generated data to conserve dry air by subtracting a constant offset from the surface pressure of each column. This can cause changes in per-mass values such as total water @@ -67,19 +72,21 @@ class CorrectorConfig: offset from the moisture advection tendency of each column. moisture_budget_correction: If not "None", force the generated data to conserve global or column-local moisture by modifying budget fields. - Options include: - - "precipitation": multiply precipitation by a scale factor - to close the global moisture budget. - - "evaporation": multiply evaporation by a scale factor - to close the global moisture budget. - - "advection_and_precipitation": after applying the "precipitation" - global-mean correction above, recompute the column-integrated - advective tendency as the budget residual, - ensuring column budget closure. - - "advection_and_evaporation": after applying the "evaporation" - global-mean correction above, recompute the column-integrated - advective tendency as the budget residual, - ensuring column budget closure. + Options are: + + - ``precipitation``: multiply precipitation by a scale factor + to close the global moisture budget. + - ``evaporation``: multiply evaporation by a scale factor + to close the global moisture budget. + - ``advection_and_precipitation``: after applying the "precipitation" + global-mean correction above, recompute the column-integrated + advective tendency as the budget residual, + ensuring column budget closure. 
+ - ``advection_and_evaporation``: after applying the "evaporation" + global-mean correction above, recompute the column-integrated + advective tendency as the budget residual, + ensuring column budget closure. + force_positive_names: Names of fields that should be forced to be greater than or equal to zero. This is useful for fields like precipitation. """ @@ -98,65 +105,87 @@ class CorrectorConfig: def build( self, - area: torch.Tensor, - sigma_coordinates: SigmaCoordinates, + gridded_operations: GriddedOperations, + vertical_coordinate: HybridSigmaPressureCoordinate, timestep: datetime.timedelta, - ) -> Optional["Corrector"]: + ) -> "Corrector": return Corrector( config=self, - area=area, - sigma_coordinates=sigma_coordinates, + gridded_operations=gridded_operations, + vertical_coordinate=vertical_coordinate, timestep=timestep, ) + @classmethod + def from_state(cls, state: Mapping[str, Any]) -> "CorrectorConfig": + return dacite.from_dict( + data_class=cls, data=state, config=dacite.Config(strict=True) + ) + -class Corrector: +class Corrector(CorrectorABC): def __init__( self, config: CorrectorConfig, - area: torch.Tensor, - sigma_coordinates: SigmaCoordinates, + gridded_operations: GriddedOperations, + vertical_coordinate: HybridSigmaPressureCoordinate, timestep: datetime.timedelta, ): self._config = config - self._area = area - self._sigma_coordinates = sigma_coordinates + self._gridded_operations = gridded_operations + self._vertical_coordinates = vertical_coordinate self._timestep = timestep + if fme.get_device() == torch.device("mps", 0): + self._dry_air_precision = torch.float32 + else: + self._dry_air_precision = torch.float64 def __call__( self, input_data: TensorMapping, gen_data: TensorMapping, - ): + forcing_data: TensorMapping, + ) -> TensorMapping: + """Apply corrections to the generated data. + + Args: + input_data: The input time step data. + gen_data: The data generated by the model, to be corrected. + forcing_data: The forcing data for the same time step as gen_data. + + Returns: + The corrected data. + """ if len(self._config.force_positive_names) > 0: # do this step before imposing other conservation correctors, since # otherwise it could end up creating violations of those constraints. 
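+            # The corrections below are applied in a fixed order: dry-air
+            # conservation, zero global-mean moisture advection, then the
+            # configured moisture budget correction.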
- gen_data = _force_positive(gen_data, self._config.force_positive_names) + gen_data = force_positive(gen_data, self._config.force_positive_names) if self._config.conserve_dry_air: gen_data = _force_conserve_dry_air( input_data=input_data, gen_data=gen_data, - area=self._area, - sigma_coordinates=self._sigma_coordinates, + area_weighted_mean=self._gridded_operations.area_weighted_mean, + vertical_coordinate=self._vertical_coordinates, + precision=self._dry_air_precision, ) if self._config.zero_global_mean_moisture_advection: gen_data = _force_zero_global_mean_moisture_advection( gen_data=gen_data, - area=self._area, + area_weighted_mean=self._gridded_operations.area_weighted_mean, ) if self._config.moisture_budget_correction is not None: gen_data = _force_conserve_moisture( input_data=input_data, gen_data=gen_data, - area=self._area, - sigma_coordinates=self._sigma_coordinates, + area_weighted_mean=self._gridded_operations.area_weighted_mean, + vertical_coordinate=self._vertical_coordinates, timestep=self._timestep, terms_to_modify=self._config.moisture_budget_correction, ) return gen_data -def _force_positive(data: TensorMapping, names: List[str]) -> TensorDict: +def force_positive(data: TensorMapping, names: List[str]) -> TensorDict: """Clamp all tensors defined by `names` to be greater than or equal to zero.""" out = {**data} for name in names: @@ -164,11 +193,16 @@ def _force_positive(data: TensorMapping, names: List[str]) -> TensorDict: return out +class AreaWeightedMean(Protocol): + def __call__(self, data: torch.Tensor, keepdim: bool) -> torch.Tensor: ... + + def _force_conserve_dry_air( input_data: TensorMapping, gen_data: TensorMapping, - area: torch.Tensor, - sigma_coordinates: SigmaCoordinates, + area_weighted_mean: AreaWeightedMean, + vertical_coordinate: HybridSigmaPressureCoordinate, + precision: torch.dtype = torch.float64, ) -> TensorDict: """ Update the generated data to conserve dry air. 
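+
+    Global means are computed in ``precision`` (float64 by default; the
+    Corrector passes float32 on MPS devices, which do not support float64).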
@@ -199,21 +233,20 @@ def _force_conserve_dry_air( if input.surface_pressure is None: raise ValueError("surface_pressure is required to force dry air conservation") gen = ClimateData(gen_data) - gen_dry_air = gen.surface_pressure_due_to_dry_air(sigma_coordinates) - global_gen_dry_air = metrics.weighted_mean(gen_dry_air, weights=area, dim=(-2, -1)) - global_target_gen_dry_air = metrics.weighted_mean( - input.surface_pressure_due_to_dry_air(sigma_coordinates), - weights=area, - dim=(-2, -1), + gen_dry_air = gen.surface_pressure_due_to_dry_air(vertical_coordinate) + global_gen_dry_air = area_weighted_mean(gen_dry_air.to(precision), keepdim=True) + global_target_gen_dry_air = area_weighted_mean( + input.surface_pressure_due_to_dry_air(vertical_coordinate).to(precision), + keepdim=True, ) error = global_gen_dry_air - global_target_gen_dry_air - new_gen_dry_air = gen_dry_air - error[..., None, None] + new_gen_dry_air = gen_dry_air.to(precision) - error try: - wat = gen.specific_total_water + wat = gen.specific_total_water.to(precision) except KeyError: raise ValueError("specific_total_water is required for conservation") - ak_diff = sigma_coordinates.ak.diff() - bk_diff = sigma_coordinates.bk.diff() + ak_diff = vertical_coordinate.ak.diff().to(precision) + bk_diff = vertical_coordinate.bk.diff().to(precision) new_pressure = (new_gen_dry_air + (ak_diff * wat).sum(-1)) / ( 1 - (bk_diff * wat).sum(-1) ) @@ -223,7 +256,7 @@ def _force_conserve_dry_air( def _force_zero_global_mean_moisture_advection( gen_data: TensorMapping, - area: torch.Tensor, + area_weighted_mean: Callable[[torch.Tensor], torch.Tensor], ) -> TensorDict: """ Update the generated data so advection conserves moisture. @@ -232,15 +265,13 @@ def _force_zero_global_mean_moisture_advection( Args: gen_data: The generated data. - area: (n_lat, n_lon) array containing relative gridcell area, in any - units including unitless. + area_weighted_mean: Computes an area-weighted mean, + removing horizontal dimensions. """ gen = ClimateData(gen_data) - mean_moisture_advection = metrics.weighted_mean( + mean_moisture_advection = area_weighted_mean( gen.tendency_of_total_water_path_due_to_advection, - weights=area, - dim=(-2, -1), ) gen.tendency_of_total_water_path_due_to_advection = ( gen.tendency_of_total_water_path_due_to_advection @@ -252,8 +283,8 @@ def _force_zero_global_mean_moisture_advection( def _force_conserve_moisture( input_data: TensorMapping, gen_data: TensorMapping, - area: torch.Tensor, - sigma_coordinates: SigmaCoordinates, + area_weighted_mean: AreaWeightedMean, + vertical_coordinate: HybridSigmaPressureCoordinate, timestep: datetime.timedelta, terms_to_modify: Literal[ "precipitation", @@ -274,9 +305,9 @@ def _force_conserve_moisture( Args: input_data: The input data. gen_data: The generated data one timestep after the input data. - area: (n_lat, n_lon) array containing relative gridcell area, in any - units including unitless. - sigma_coordinates: The sigma coordinates. + area_weighted_mean: Computes an area-weighted mean, + removing horizontal dimensions. + vertical_coordinate: The sigma coordinates. timestep: Timestep of the model. terms_to_modify: Which terms to modify, in addition to modifying surface pressure to conserve dry air mass. 
One of: @@ -288,20 +319,14 @@ def _force_conserve_moisture( input = ClimateData(input_data) gen = ClimateData(gen_data) - gen_total_water_path = gen.total_water_path(sigma_coordinates) + gen_total_water_path = gen.total_water_path(vertical_coordinate) timestep_seconds = timestep / datetime.timedelta(seconds=1) twp_total_tendency = ( - gen_total_water_path - input.total_water_path(sigma_coordinates) + gen_total_water_path - input.total_water_path(vertical_coordinate) ) / timestep_seconds - twp_tendency_global_mean = metrics.weighted_mean( - twp_total_tendency, weights=area, dim=(-2, -1) - ) - evaporation_global_mean = metrics.weighted_mean( - gen.evaporation_rate, weights=area, dim=(-2, -1) - ) - precipitation_global_mean = metrics.weighted_mean( - gen.precipitation_rate, weights=area, dim=(-2, -1) - ) + twp_tendency_global_mean = area_weighted_mean(twp_total_tendency, keepdim=True) + evaporation_global_mean = area_weighted_mean(gen.evaporation_rate, keepdim=True) + precipitation_global_mean = area_weighted_mean(gen.precipitation_rate, keepdim=True) if terms_to_modify.endswith("precipitation"): # We want to achieve # global_mean(twp_total_tendency) = ( @@ -324,20 +349,16 @@ def _force_conserve_moisture( # new_precip_rate = ( # new_global_precip_rate / current_global_precip_rate # ) * current_precip_rate - gen.precipitation_rate = ( - gen.precipitation_rate - * (new_precipitation_global_mean / precipitation_global_mean)[ - ..., None, None - ] + gen.precipitation_rate = gen.precipitation_rate * ( + new_precipitation_global_mean / precipitation_global_mean ) elif terms_to_modify.endswith("evaporation"): # Derived similarly as for "precipitation" case. new_evaporation_global_mean = ( twp_tendency_global_mean + precipitation_global_mean ) - gen.evaporation_rate = ( - gen.evaporation_rate - * (new_evaporation_global_mean / evaporation_global_mean)[..., None, None] + gen.evaporation_rate = gen.evaporation_rate * ( + new_evaporation_global_mean / evaporation_global_mean ) if terms_to_modify.startswith("advection"): # Having already corrected the global-mean budget, we recompute diff --git a/fme/fme/core/corrector/ocean.py b/fme/fme/core/corrector/ocean.py new file mode 100644 index 0000000..b890090 --- /dev/null +++ b/fme/fme/core/corrector/ocean.py @@ -0,0 +1,83 @@ +import dataclasses +import datetime +from types import MappingProxyType +from typing import Any, List, Mapping, Optional + +import dacite + +from fme.core.coordinates import HybridSigmaPressureCoordinate +from fme.core.corrector.corrector import force_positive +from fme.core.corrector.registry import CorrectorABC, CorrectorConfigProtocol +from fme.core.gridded_ops import GriddedOperations +from fme.core.masking import MaskingConfig +from fme.core.registry.corrector import CorrectorSelector +from fme.core.stacker import Stacker +from fme.core.typing_ import TensorMapping + +OCEAN_FIELD_NAME_PREFIXES = MappingProxyType( + { + "surface_height": ["zos"], + "salinity": ["so_"], + "potential_temperature": ["thetao_"], + "zonal_velocity": ["uo_"], + "meridional_velocity": ["vo_"], + } +) + + +@CorrectorSelector.register("ocean_corrector") +@dataclasses.dataclass +class OceanCorrectorConfig(CorrectorConfigProtocol): + masking: Optional[MaskingConfig] = None + force_positive_names: List[str] = dataclasses.field(default_factory=list) + + def build( + self, + gridded_operations: GriddedOperations, + vertical_coordinate: HybridSigmaPressureCoordinate, + timestep: datetime.timedelta, + ): + return OceanCorrector( + config=self, + 
gridded_operations=gridded_operations,
+            vertical_coordinate=vertical_coordinate,
+            timestep=timestep,
+        )
+
+    @classmethod
+    def from_state(cls, state: Mapping[str, Any]) -> "OceanCorrectorConfig":
+        return dacite.from_dict(
+            data_class=cls, data=state, config=dacite.Config(strict=True)
+        )
+
+
+class OceanCorrector(CorrectorABC):
+    def __init__(
+        self,
+        config: OceanCorrectorConfig,
+        gridded_operations: GriddedOperations,
+        vertical_coordinate: HybridSigmaPressureCoordinate,
+        timestep: datetime.timedelta,
+    ):
+        self._config = config
+        self._gridded_operations = gridded_operations
+        self._vertical_coordinates = vertical_coordinate
+        self._timestep = timestep
+
+        if config.masking is not None:
+            self._masking = config.masking.build()
+        else:
+            self._masking = None
+        self._stacker = Stacker(OCEAN_FIELD_NAME_PREFIXES)
+
+    def __call__(
+        self,
+        input_data: TensorMapping,
+        gen_data: TensorMapping,
+        forcing_data: TensorMapping,
+    ) -> TensorMapping:
+        if self._masking is not None:
+            gen_data = self._masking(self._stacker, gen_data, input_data)
+        if len(self._config.force_positive_names) > 0:
+            gen_data = force_positive(gen_data, self._config.force_positive_names)
+        return gen_data
diff --git a/fme/fme/core/corrector/registry.py b/fme/fme/core/corrector/registry.py
new file mode 100644
index 0000000..e0efef1
--- /dev/null
+++ b/fme/fme/core/corrector/registry.py
@@ -0,0 +1,34 @@
+import abc
+import datetime
+from typing import Any, Mapping, Protocol
+
+from fme.core.coordinates import HybridSigmaPressureCoordinate
+from fme.core.gridded_ops import GriddedOperations
+from fme.core.typing_ import TensorMapping
+
+
+class CorrectorConfigProtocol(Protocol):
+    def build(
+        self,
+        gridded_operations: GriddedOperations,
+        vertical_coordinate: HybridSigmaPressureCoordinate,
+        timestep: datetime.timedelta,
+    ) -> "CorrectorABC": ...
+
+    @classmethod
+    def from_state(cls, state: Mapping[str, Any]) -> "CorrectorConfigProtocol":
+        """
+        Create a corrector config from a dictionary containing all the
+        information needed to build it.
+        """
+        ...
+
+
+class CorrectorABC(abc.ABC):
+    @abc.abstractmethod
+    def __call__(
+        self,
+        input_data: TensorMapping,
+        gen_data: TensorMapping,
+        forcing_data: TensorMapping,
+    ) -> TensorMapping: ...
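+
+
+# A minimal sketch (not part of the library) of how a new corrector would be
+# wired in, mirroring the ocean and atmosphere correctors; "MyCorrector" and
+# "my_corrector" are hypothetical names. CorrectorSelector lives in
+# fme.core.registry.corrector.
+#
+#     @CorrectorSelector.register("my_corrector")
+#     @dataclasses.dataclass
+#     class MyCorrectorConfig:
+#         def build(self, gridded_operations, vertical_coordinate, timestep):
+#             return MyCorrector(self, gridded_operations, vertical_coordinate, timestep)
+#
+#         @classmethod
+#         def from_state(cls, state):
+#             return dacite.from_dict(
+#                 data_class=cls, data=state, config=dacite.Config(strict=True)
+#             )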
diff --git a/fme/fme/core/corrector/test_corrector.py b/fme/fme/core/corrector/test_corrector.py new file mode 100644 index 0000000..25b67f0 --- /dev/null +++ b/fme/fme/core/corrector/test_corrector.py @@ -0,0 +1,260 @@ +import datetime +from typing import Callable, Optional, Tuple + +import numpy as np +import pytest +import torch + +from fme.ace.inference.derived_variables import total_water_path_budget_residual +from fme.core import ClimateData, metrics +from fme.core.climate_data import compute_dry_air_absolute_differences +from fme.core.coordinates import HybridSigmaPressureCoordinate +from fme.core.corrector.ocean import OceanCorrector +from fme.core.gridded_ops import GriddedOperations, HEALPixOperations, LatLonOperations +from fme.core.registry.corrector import CorrectorSelector +from fme.core.typing_ import TensorMapping + +from .corrector import ( + _force_conserve_dry_air, + _force_conserve_moisture, + _force_zero_global_mean_moisture_advection, + force_positive, +) + +TIMESTEP = datetime.timedelta(hours=6) + + +def get_dry_air_nonconservation( + data: TensorMapping, + area_weighted_mean: Callable[[torch.Tensor], torch.Tensor], + vertical_coordinate: HybridSigmaPressureCoordinate, +): + """ + Computes the time-average one-step absolute difference in surface pressure due to + changes in globally integrated dry air. + + Args: + data: A mapping from variable name to tensor of shape + [sample, time, lat, lon], in physical units. specific_total_water in kg/kg + and surface_pressure in Pa must be present. + area_weighted_mean: Computes the area-weighted mean of a tensor, removing the + horizontal dimensions. + vertical_coordinate: The vertical coordinates of the model. + """ + return compute_dry_air_absolute_differences( + ClimateData(data), + area_weighted_mean=area_weighted_mean, + vertical_coordinate=vertical_coordinate, + ).mean() + + +def test_force_no_global_mean_moisture_advection(): + torch.random.manual_seed(0) + data = { + "tendency_of_total_water_path_due_to_advection": torch.rand(size=(3, 2, 5, 5)), + } + area_weights = 1.0 + torch.rand(size=(5, 5)) + original_mean = metrics.weighted_mean( + data["tendency_of_total_water_path_due_to_advection"], + weights=area_weights, + dim=[-2, -1], + ) + assert (original_mean.abs() > 0.0).all() + fixed_data = _force_zero_global_mean_moisture_advection( + data, + area_weighted_mean=LatLonOperations(area_weights).area_weighted_mean, + ) + new_mean = metrics.weighted_mean( + fixed_data["tendency_of_total_water_path_due_to_advection"], + weights=area_weights, + dim=[-2, -1], + ) + assert (new_mean.abs() < original_mean.abs()).all() + np.testing.assert_almost_equal(new_mean.cpu().numpy(), 0.0, decimal=6) + + +@pytest.mark.parametrize( + "size, use_area", + [ + pytest.param((3, 2, 5, 5), True, id="latlon"), + pytest.param((3, 12, 2, 3, 3), False, id="healpix"), + ], +) +def test_force_conserve_dry_air(size: Tuple[int, ...], use_area: bool): + torch.random.manual_seed(0) + data = { + "PRESsfc": 10.0 + torch.rand(size=size), + "specific_total_water_0": torch.rand(size=size), + "specific_total_water_1": torch.rand(size=size), + } + vertical_coordinate = HybridSigmaPressureCoordinate( + ak=torch.asarray([3.0, 1.0, 0.0]), bk=torch.asarray([0.0, 0.6, 1.0]) + ) + if use_area: + area_weights: Optional[torch.Tensor] = 1.0 + torch.rand( + size=(size[-2], size[-1]) + ) + else: + area_weights = None + if area_weights is not None: + gridded_operations: GriddedOperations = LatLonOperations(area_weights) + else: + gridded_operations = HEALPixOperations() + 
original_nonconservation = get_dry_air_nonconservation( + data, + vertical_coordinate=vertical_coordinate, + area_weighted_mean=gridded_operations.area_weighted_mean, + ) + assert original_nonconservation > 0.0 + in_data = {k: v.select(dim=1, index=0) for k, v in data.items()} + out_data = {k: v.select(dim=1, index=1) for k, v in data.items()} + fixed_out_data = _force_conserve_dry_air( + in_data, + out_data, + vertical_coordinate=vertical_coordinate, + area_weighted_mean=gridded_operations.area_weighted_mean, + ) + new_data = { + k: torch.stack([v, fixed_out_data[k]], dim=1) for k, v in in_data.items() + } + new_nonconservation = get_dry_air_nonconservation( + new_data, + vertical_coordinate=vertical_coordinate, + area_weighted_mean=gridded_operations.area_weighted_mean, + ) + assert new_nonconservation < original_nonconservation + np.testing.assert_almost_equal(new_nonconservation.cpu().numpy(), 0.0, decimal=6) + + +@pytest.mark.parametrize("dataset", ["fv3", "e3sm"]) +@pytest.mark.parametrize( + "global_only, terms_to_modify", + [ + (True, "precipitation"), + (True, "evaporation"), + (False, "advection_and_precipitation"), + (False, "advection_and_evaporation"), + ], +) +@pytest.mark.parametrize( + "size, use_area", + [ + pytest.param((3, 2, 5, 5), True, id="latlon"), + pytest.param((3, 12, 2, 3, 3), False, id="healpix"), + ], +) +def test_force_conserve_moisture( + dataset: str, + global_only: bool, + terms_to_modify, + size: Tuple[int, ...], + use_area: bool, +): + torch.random.manual_seed(0) + if dataset == "fv3": + data = { + "PRESsfc": 10.0 + torch.rand(size=size), + "specific_total_water_0": torch.rand(size=size), + "specific_total_water_1": torch.rand(size=size), + "PRATEsfc": torch.rand(size=size), + "LHTFLsfc": torch.rand(size=size), + "tendency_of_total_water_path_due_to_advection": torch.rand(size=size), + } + if dataset == "e3sm": + data = { + "PS": 10.0 + torch.rand(size=size), + "specific_total_water_0": torch.rand(size=size), + "specific_total_water_1": torch.rand(size=size), + "surface_precipitation_rate": torch.rand(size=size), + "LHFLX": torch.rand(size=size), + "tendency_of_total_water_path_due_to_advection": torch.rand(size=size), + } + vertical_coordinate = HybridSigmaPressureCoordinate( + ak=torch.asarray([3.0, 1.0, 0.0]), bk=torch.asarray([0.0, 0.6, 1.0]) + ) + if use_area: + ops: GriddedOperations = LatLonOperations(1.0 + torch.rand(size=(5, 5))) + else: + ops = HEALPixOperations() + data["tendency_of_total_water_path_due_to_advection"] -= ops.area_weighted_mean( + data["tendency_of_total_water_path_due_to_advection"], keepdim=True + ) + original_budget_residual = total_water_path_budget_residual( + ClimateData(data), + vertical_coordinate=vertical_coordinate, + timestep=TIMESTEP, + )[:, 1] # no meaning for initial value data, want first timestep + if global_only: + original_budget_residual = ops.area_weighted_mean( + original_budget_residual, keepdim=True + ) + original_budget_residual = original_budget_residual.cpu().numpy() + original_dry_air = ( + ClimateData(data) + .surface_pressure_due_to_dry_air(vertical_coordinate) + .cpu() + .numpy() + ) + assert np.any(np.abs(original_budget_residual) > 0.0) + in_data = {k: v.select(dim=1, index=0) for k, v in data.items()} + out_data = {k: v.select(dim=1, index=1) for k, v in data.items()} + fixed_out_data = _force_conserve_moisture( + in_data, + out_data, + vertical_coordinate=vertical_coordinate, + area_weighted_mean=ops.area_weighted_mean, + timestep=TIMESTEP, + terms_to_modify=terms_to_modify, + ) + new_data = { + 
k: torch.stack([v, fixed_out_data[k]], dim=1) for k, v in in_data.items() + } + new_budget_residual = total_water_path_budget_residual( + ClimateData(new_data), + vertical_coordinate=vertical_coordinate, + timestep=TIMESTEP, + )[:, 1] # no meaning for initial value data, want first timestep + new_dry_air = ( + ClimateData(data) + .surface_pressure_due_to_dry_air(vertical_coordinate) + .cpu() + .numpy() + ) + + global_budget_residual = ops.area_weighted_mean(new_budget_residual).cpu().numpy() + np.testing.assert_almost_equal(global_budget_residual, 0.0, decimal=6) + + if not global_only: + new_budget_residual = new_budget_residual.cpu().numpy() + assert np.all(np.abs(new_budget_residual) < np.abs(original_budget_residual)) + np.testing.assert_almost_equal(new_budget_residual, 0.0, decimal=6) + + np.testing.assert_almost_equal(new_dry_air, original_dry_air, decimal=6) + + +def test_force_positive(): + data = { + "foo": torch.tensor([[-1.0, 0.0], [0.0, 1.0], [1.0, 2.0]]), + "bar": torch.tensor([[-1.0, 0.0], [0.0, -3.0], [1.0, 2.0]]), + } + original_min = torch.min(data["foo"]) + assert original_min < 0.0 + fixed_data = force_positive(data, ["foo"]) + new_min = torch.min(fixed_data["foo"]) + # Ensure the minimum value of 'foo' is now 0 + torch.testing.assert_close(new_min, torch.tensor(0.0)) + # Ensure other variables are not modified + torch.testing.assert_close(fixed_data["bar"], data["bar"]) + + +def test_corrector_selector(): + selector = CorrectorSelector( + type="ocean_corrector", + config={"masking": {"mask_name": "mask", "mask_value": 1}}, + ) + ops: GriddedOperations = LatLonOperations(1.0 + torch.rand(size=(5, 5))) + vertical: HybridSigmaPressureCoordinate = HybridSigmaPressureCoordinate( + ak=torch.tensor([1.0, 0.5, 0.0]), bk=torch.tensor([0.0, 0.5, 1.0]) + ) + corrector = selector.build(ops, vertical, TIMESTEP) + assert isinstance(corrector, OceanCorrector) diff --git a/fme/fme/core/data_loading/config.py b/fme/fme/core/data_loading/config.py deleted file mode 100644 index 735cb40..0000000 --- a/fme/fme/core/data_loading/config.py +++ /dev/null @@ -1,122 +0,0 @@ -import dataclasses -from typing import Literal, Optional, Sequence, Union - -import xarray as xr - -from fme.core.distributed import Distributed - - -@dataclasses.dataclass -class Slice: - """ - Configuration of a python `slice` built-in. - - Required because `slice` cannot be initialized directly by dacite. - - Attributes: - start: Start index of the slice. - stop: Stop index of the slice. - step: Step of the slice. - """ - - start: Optional[int] = None - stop: Optional[int] = None - step: Optional[int] = None - - @property - def slice(self) -> slice: - return slice(self.start, self.stop, self.step) - - -@dataclasses.dataclass -class TimeSlice: - """ - Configuration of a slice of times. Step is an integer-valued index step. - - Note: start_time and stop_time may be provided as partial time strings and the - stop_time will be included in the slice. See more details in `Xarray docs`_. - - Attributes: - start_time: Start time of the slice. - stop_time: Stop time of the slice. - step: Step of the slice. - - .. 
_Xarray docs: - https://docs.xarray.dev/en/latest/user-guide/weather-climate.html#non-standard-calendars-and-dates-outside-the-nanosecond-precision-range # noqa - """ - - start_time: Optional[str] = None - stop_time: Optional[str] = None - step: Optional[int] = None - - def slice(self, times: xr.CFTimeIndex) -> slice: - return times.slice_indexer(self.start_time, self.stop_time, self.step) - - -@dataclasses.dataclass -class XarrayDataConfig: - """ - Attributes: - data_path: Path to the data. - file_pattern: Glob pattern to match files in the data_path. - n_repeats: Number of times to repeat the dataset (in time). It is up - to the user to ensure that the input dataset to repeat results in - data that is reasonably continuous across repetitions. - engine: Backend for xarray.open_dataset. Currently supported options - are "netcdf4" (the default) and "h5netcdf". Only valid when using - XarrayDataset. - spatial_dimensions: Specifies the spatial dimensions for the grid, default - is lat/lon. - subset: Slice defining a subset of the XarrayDataset to load. This can - either be a `Slice` of integer indices or a `TimeSlice` of timestamps. - infer_timestep: Whether to infer the timestep from the provided data. - This should be set to True (the default) for ACE training. It may - be useful to toggle this to False for applications like downscaling, - which do not depend on the timestep of the data and therefore lack - the additional requirement that the data be ordered and evenly - spaced in time. It must be set to True if n_repeats > 1 in order - to be able to infer the full time coordinate. - """ - - data_path: str - file_pattern: str = "*.nc" - n_repeats: int = 1 - engine: Optional[Literal["netcdf4", "h5netcdf", "zarr"]] = None - spatial_dimensions: Literal["healpix", "latlon"] = "latlon" - subset: Union[Slice, TimeSlice] = dataclasses.field(default_factory=Slice) - infer_timestep: bool = True - - def __post_init__(self): - if self.n_repeats > 1 and not self.infer_timestep: - raise ValueError( - "infer_timestep must be True if n_repeats is greater than 1" - ) - - -@dataclasses.dataclass -class DataLoaderConfig: - """ - Attributes: - dataset: A sequence of configurations each defining a dataset - to be loaded. This sequence of datasets will be concatenated. - batch_size: Number of samples per batch. - num_data_workers: Number of parallel workers to use for data loading. - prefetch_factor: how many batches a single data worker will attempt to - hold in host memory at a given time. - strict_ensemble: Whether to enforce that the ensemble members have the same - dimensions and coordinates. 
- """ - - dataset: Sequence[XarrayDataConfig] - batch_size: int - num_data_workers: int - prefetch_factor: Optional[int] = None - strict_ensemble: bool = True - - def __post_init__(self): - dist = Distributed.get_instance() - if self.batch_size % dist.world_size != 0: - raise ValueError( - "batch_size must be divisible by the number of parallel " - f"workers, got {self.batch_size} and {dist.world_size}" - ) diff --git a/fme/fme/core/data_loading/data_typing.py b/fme/fme/core/data_loading/data_typing.py deleted file mode 100644 index a039db1..0000000 --- a/fme/fme/core/data_loading/data_typing.py +++ /dev/null @@ -1,276 +0,0 @@ -import abc -import dataclasses -import datetime -from collections import namedtuple -from typing import Dict, List, Literal, Mapping, Optional, Tuple - -import numpy as np -import torch -import xarray as xr -from astropy_healpix import HEALPix - -from fme.core.typing_ import TensorDict, TensorMapping -from fme.core.winds import lon_lat_to_xyz - -VariableMetadata = namedtuple("VariableMetadata", ["units", "long_name"]) - - -@dataclasses.dataclass -class SigmaCoordinates: - """ - Defines pressure at interface levels according to the following formula: - p(k) = a(k) + b(k)*ps - - where ps is the surface pressure, a and b are the sigma coordinates. - - Attributes: - ak: a(k) coefficients as a 1-dimensional tensor - bk: b(k) coefficients as a 1-dimensional tensor - """ - - ak: torch.Tensor - bk: torch.Tensor - - @property - def coords(self) -> Mapping[str, np.ndarray]: - return {"ak": self.ak.cpu().numpy(), "bk": self.bk.cpu().numpy()} - - def to(self, device: str) -> "SigmaCoordinates": - return SigmaCoordinates( - ak=self.ak.to(device), - bk=self.bk.to(device), - ) - - def as_dict(self) -> TensorMapping: - return {"ak": self.ak, "bk": self.bk} - - -@dataclasses.dataclass -class HorizontalCoordinates(abc.ABC): - """ - Parent class for horizontal coordinate system grids. - Contains coords which must be subclassed to provide the coordinates. - """ - - @property - @abc.abstractmethod - def coords(self) -> Mapping[str, np.ndarray]: - pass - - @property - @abc.abstractmethod - def xyz(self) -> Tuple[float, float, float]: - pass - - @property - @abc.abstractmethod - def dims(self) -> List[str]: - pass - - @property - @abc.abstractmethod - def default_sizes(self) -> Dict[str, int]: - pass - - @property - @abc.abstractmethod - def grid(self) -> Literal["equiangular", "legendre-gauss", "healpix"]: - pass - - # A temporary solution for training which allows us to aggregate along the - # latitude dimension. - # TODO: https://github.com/ai2cm/full-model/issues/1003 - @abc.abstractmethod - def get_lat(self) -> torch.Tensor: - pass - - -@dataclasses.dataclass -class LatLonCoordinates(HorizontalCoordinates): - """ - Defines a (latitude, longitude) grid. - - Attributes: - lat: 1-dimensional tensor of latitudes - lon: 1-dimensional tensor of longitudes - """ - - lon: torch.Tensor - lat: torch.Tensor - - lat_name: str = "lat" - lon_name: str = "lon" - - @property - def coords(self) -> Mapping[str, np.ndarray]: - # TODO: Replace with lat/lon name? 
- return { - "lat": self.lat.cpu().numpy(), - "lon": self.lon.cpu().numpy(), - } - - @property - def xyz(self) -> Tuple[float, float, float]: - lats, lons = np.broadcast_arrays(self.lat[:, None], self.lon[None, :]) - return lon_lat_to_xyz(lons, lats) - - # TODO: https://github.com/ai2cm/full-model/issues/1003 - def get_lat(self) -> torch.Tensor: - return self.lat - - @property - def dims(self) -> List[str]: - return [self.lat_name, self.lon_name] - - @property - def default_sizes(self) -> Dict[str, int]: - return {self.lat_name: 12, self.lon_name: 6} - - @property - def grid(self) -> Literal["equiangular", "legendre-gauss"]: - if torch.allclose( - self.lat[1:] - self.lat[:-1], - self.lat[1] - self.lat[0], - ): - return "equiangular" - else: - return "legendre-gauss" - - -@dataclasses.dataclass -class HEALPixCoordinates(HorizontalCoordinates): - """ - Defines a HEALPix (face, height, width) grid. See https://healpix.jpl.nasa.gov/ for - more information. - - Attributes: - face: 1-dimensional tensor of faces - height: 1-dimensional tensor of heights - width: 1-dimensional tensor of widths - """ - - face: torch.Tensor - height: torch.Tensor - width: torch.Tensor - - @property - def coords(self) -> Mapping[str, np.ndarray]: - return { - "face": self.face.cpu().numpy(), - "height": self.height.cpu().numpy(), - "width": self.width.cpu().numpy(), - } - - @property - def xyz(self) -> Tuple[float, float, float]: - hp = HEALPix(nside=len(self.height), order="ring") - return hp.healpix_to_xyz([self.face, self.height, self.width]) - - @property - def dims(self) -> List[str]: - return ["face", "height", "width"] - - @property - def default_sizes(cls) -> Dict[str, int]: - return {"face": 12, "width": 64, "height": 64} - - # TODO: https://github.com/ai2cm/full-model/issues/1003 - # This is currently the dummy solution. - def get_lat(self) -> torch.Tensor: - raise NotImplementedError( - "healpix does not support get_lat. If latitude is needed \ - for some reason, you may use this class's self.xyz property to derive it." - ) - - @property - def grid(self) -> Literal["healpix"]: - return "healpix" - - -class Dataset(torch.utils.data.Dataset, abc.ABC): - @abc.abstractproperty - def metadata(self) -> Mapping[str, VariableMetadata]: - ... - - @abc.abstractproperty - def area_weights(self) -> torch.Tensor: - ... - - @abc.abstractproperty - def horizontal_coordinates(self) -> HorizontalCoordinates: - ... - - @abc.abstractproperty - def sigma_coordinates(self) -> SigmaCoordinates: - ... - - @abc.abstractproperty - def is_remote(self) -> bool: - ... - - @abc.abstractmethod - def get_sample_by_time_slice( - self, time_slice: slice - ) -> Tuple[TensorDict, xr.DataArray]: - """ - Returns a sample of data for the given time slice. - - Args: - time_slice: The time slice to return data for. - - Returns: - A tuple whose first item is a mapping from variable - name to tensor of shape [n_time, n_lat, n_lon] and - whose second item is a time coordinate array. - """ - ... - - -@dataclasses.dataclass -class GriddedData: - """ - Data as required for pytorch training. - - The data is assumed to be gridded, and attributes are included for - performing operations on gridded data. - - Attributes: - loader: torch DataLoader, which returns batches of type - TensorMapping where keys indicate variable name. - Each tensor has shape - [batch_size, time_window_size, n_channels, n_lat, n_lon]. - metadata: Metadata for each variable. - area_weights: Weights for each grid cell, used for computing area-weighted - averages. 
Has shape [n_lat, n_lon]. - sigma_coordinates: Sigma coordinates for each grid cell, used for computing - pressure levels. - horizontal_coordinates: Lat/lon coordinates for the data. - timestep: Timestep of the model. - sampler: Optional sampler for the data loader. Provided to allow support for - distributed training. - """ - - loader: torch.utils.data.DataLoader - metadata: Mapping[str, VariableMetadata] - area_weights: torch.Tensor - sigma_coordinates: SigmaCoordinates - horizontal_coordinates: HorizontalCoordinates - timestep: datetime.timedelta - sampler: Optional[torch.utils.data.Sampler] = None - - @property - def dataset(self) -> Dataset: - return self.loader.dataset - - @property - def coords(self) -> Mapping[str, np.ndarray]: - return { - **self.horizontal_coordinates.coords, - **self.sigma_coordinates.coords, - } - - @property - def grid(self) -> Literal["equiangular", "legendre-gauss", "healpix"]: - """If the latitudes are equiangular, assume a regular grid and otherwise - assume a gaussian, or 'legendre-gauss' grid.""" - return self.horizontal_coordinates.grid diff --git a/fme/fme/core/data_loading/getters.py b/fme/fme/core/data_loading/getters.py deleted file mode 100644 index 6a933c5..0000000 --- a/fme/fme/core/data_loading/getters.py +++ /dev/null @@ -1,221 +0,0 @@ -import logging -import warnings -from typing import List, Sequence - -import numpy as np -import torch.utils.data -import xarray as xr -from torch.utils.data.distributed import DistributedSampler -from torch.utils.data.sampler import RandomSampler - -from fme.core.data_loading.config import DataLoaderConfig, XarrayDataConfig -from fme.core.device import using_gpu -from fme.core.distributed import Distributed - -from ._xarray import XarrayDataset, as_index_slice, subset_dataset -from .data_typing import GriddedData -from .inference import ( - ExplicitIndices, - ForcingDataLoaderConfig, - InferenceDataLoaderConfig, - InferenceDataset, -) -from .requirements import DataRequirements -from .utils import BatchData - - -def _all_same(iterable, cmp=lambda x, y: x == y): - it = iter(iterable) - try: - first = next(it) - except StopIteration: - return True - return all(cmp(first, rest) for rest in it) - - -def get_datasets( - dataset_configs: Sequence[XarrayDataConfig], requirements: DataRequirements -) -> List[XarrayDataset]: - datasets = [] - for config in dataset_configs: - dataset = XarrayDataset(config, requirements) - index_slice = as_index_slice(config.subset, dataset) - dataset = subset_dataset(dataset, index_slice) - datasets.append(dataset) - return datasets - - -def get_dataset( - dataset_configs: Sequence[XarrayDataConfig], - requirements: DataRequirements, - strict: bool = True, -) -> torch.utils.data.ConcatDataset[XarrayDataset]: - datasets = get_datasets(dataset_configs, requirements) - - if not _all_same([d.metadata for d in datasets]): - if strict: - raise ValueError("Metadata for each ensemble member should be the same.") - else: - warnings.warn( - "Metadata for each ensemble member are not the same. You may be " - "concatenating incompatible datasets." - ) - sigma_coords = [d.sigma_coordinates for d in datasets] - ak, bk = list( - zip(*[(s.ak.cpu().numpy(), s.bk.cpu().numpy()) for s in sigma_coords]) - ) - if not (_all_same(ak, cmp=np.allclose) and _all_same(bk, cmp=np.allclose)): - if strict: - raise ValueError( - "Sigma coordinates for each ensemble member should be the same." - ) - else: - warnings.warn( - "Vertical coordinates for each ensemble member are not the same. 
You " - "may be concatenating incompatible datasets." - ) - - ensemble = torch.utils.data.ConcatDataset(datasets) - ensemble.metadata = datasets[0].metadata # type: ignore - ensemble.area_weights = datasets[0].area_weights # type: ignore - ensemble.sigma_coordinates = datasets[0].sigma_coordinates # type: ignore - ensemble.timestep = datasets[0].timestep # type: ignore - ensemble.horizontal_coordinates = datasets[0].horizontal_coordinates # type: ignore - ensemble.is_remote = any(d.is_remote for d in datasets) # type: ignore - return ensemble - - -def get_data_loader( - config: DataLoaderConfig, - train: bool, - requirements: DataRequirements, -) -> GriddedData: - """ - Args: - config: Parameters for the data loader. - train: Whether loader is intended for training or validation data; if True, - then data will be shuffled. - requirements: Data requirements for the model. - """ - dataset = get_dataset(config.dataset, requirements, strict=config.strict_ensemble) - dist = Distributed.get_instance() - - if dist.is_distributed(): - sampler = DistributedSampler(dataset, shuffle=train) - else: - sampler = RandomSampler(dataset) if train else None - - if dataset.is_remote: - # GCSFS and S3FS are not fork-safe, so we need to use forkserver - mp_context = "forkserver" - persistent_workers = True - else: - mp_context = None - persistent_workers = False - - dataloader = torch.utils.data.DataLoader( - dataset, - batch_size=dist.local_batch_size(int(config.batch_size)), - num_workers=config.num_data_workers, - sampler=sampler, - drop_last=True, - pin_memory=using_gpu(), - collate_fn=BatchData.from_sample_tuples, - prefetch_factor=config.prefetch_factor, - multiprocessing_context=mp_context, - persistent_workers=persistent_workers, - ) - - if len(dataloader) == 0: - raise ValueError( - "No batches in dataloader: " - f"{len(dataloader.dataset)} samples, {len(dataloader)} batches. " - f"Batch size is {dataloader.batch_size}" - ) - - return GriddedData( - loader=dataloader, - metadata=dataset.metadata, - area_weights=dataset.area_weights, - sampler=sampler, - sigma_coordinates=dataset.sigma_coordinates, - timestep=dataset.timestep, - horizontal_coordinates=dataset.horizontal_coordinates, - ) - - -def get_inference_data( - config: InferenceDataLoaderConfig, - forward_steps_in_memory: int, - requirements: DataRequirements, -) -> GriddedData: - """ - Args: - config: Parameters for the data loader. - forward_steps_in_memory: Number of forward steps to keep in memory at once. - requirements: Data requirements for the model. - - Returns: - A data loader for inference with coordinates and metadata. 
- """ - dataset = InferenceDataset(config, forward_steps_in_memory, requirements) - - if dataset.is_remote: - # GCSFS and S3FS are not fork-safe, so we need to use forkserver - # persist workers since startup is slow - mp_context = "forkserver" - persistent_workers = True - else: - mp_context = None - persistent_workers = False - - logging.info(f"Multiprocessing inference context: {mp_context or 'fork'}") - - # we roll our own batching in InferenceDataset, which is why batch_size=None below - loader = torch.utils.data.DataLoader( - dataset, - batch_size=None, - num_workers=config.num_data_workers, - shuffle=False, - pin_memory=using_gpu(), - multiprocessing_context=mp_context, - persistent_workers=persistent_workers, - ) - return GriddedData( - loader=loader, - metadata=dataset.metadata, - area_weights=dataset.area_weights, - sigma_coordinates=dataset.sigma_coordinates, - timestep=dataset.timestep, - horizontal_coordinates=dataset.horizontal_coordinates, - ) - - -def get_forcing_data( - config: ForcingDataLoaderConfig, - forward_steps_in_memory: int, - requirements: DataRequirements, - initial_times: xr.DataArray, -) -> GriddedData: - """Return a GriddedData loader for forcing data only. This function determines the - start indices for the forcing data based on the initial times provided. - - Args: - config: Parameters for the forcing data loader. - forward_steps_in_memory: Number of forward steps to provide per window of - forcing data that will be returned by loader. - requirements: Data requirements for the forcing data. - initial_times: Desired initial times for the forcing data. This must be a 1D - data array, whose length determines the ensemble size. - - Returns: - A data loader for forcing data with coordinates and metadata. - """ - available_times = XarrayDataset(config.dataset, requirements).all_times - start_time_indices = [] - for time in initial_times.values: - start_time_indices.append(available_times.get_loc(time)) - inference_config = config.build_inference_config( - start_indices=ExplicitIndices(start_time_indices) - ) - return get_inference_data(inference_config, forward_steps_in_memory, requirements) diff --git a/fme/fme/core/data_loading/inference.py b/fme/fme/core/data_loading/inference.py deleted file mode 100644 index 1670f0e..0000000 --- a/fme/fme/core/data_loading/inference.py +++ /dev/null @@ -1,240 +0,0 @@ -import dataclasses -import datetime -from math import ceil -from typing import Sequence, Union - -import cftime -import numpy as np -import torch -import xarray as xr - -from fme.core.data_loading._xarray import XarrayDataset -from fme.core.data_loading.config import Slice, XarrayDataConfig -from fme.core.data_loading.data_typing import HorizontalCoordinates, SigmaCoordinates -from fme.core.data_loading.requirements import DataRequirements -from fme.core.data_loading.utils import BatchData -from fme.core.distributed import Distributed - - -@dataclasses.dataclass -class TimestampList: - """ - Configuration for a list of timestamps. - - Attributes: - times: List of timestamps. - timestamp_format: Format of the timestamps. 
- """ - - times: Sequence[str] - timestamp_format: str = "%Y-%m-%dT%H:%M:%S" - - def as_indices(self, time_index: xr.CFTimeIndex) -> np.ndarray: - datetimes = [ - cftime.datetime.strptime( - t, self.timestamp_format, calendar=time_index.calendar - ) - for t in self.times - ] - (indices,) = time_index.isin(datetimes).nonzero() - if len(indices) != len(self.times): - missing_times = set(datetimes) - set(time_index[indices]) - raise ValueError( - f"Inference initial condition timestamps {missing_times} " - "were not found in the dataset." - ) - return indices - - @property - def n_initial_conditions(self) -> int: - return len(self.times) - - -@dataclasses.dataclass -class InferenceInitialConditionIndices: - """ - Configuration of the indices for initial conditions during inference. - - Attributes: - n_initial_conditions: Number of initial conditions to use. - first: Index of the first initial condition. - interval: Interval between initial conditions. - """ - - n_initial_conditions: int - first: int = 0 - interval: int = 1 - - def __post_init__(self): - if self.interval < 0: - raise ValueError("interval must be positive") - - def as_indices(self) -> np.ndarray: - stop = self.n_initial_conditions * self.interval + self.first - return np.arange(self.first, stop, self.interval) - - -@dataclasses.dataclass -class ExplicitIndices: - """ - Configure indices providing them explicitly. - - Attributes: - list: List of integer indices. - """ - - list: Sequence[int] - - def as_indices(self) -> np.ndarray: - return np.array(self.list) - - @property - def n_initial_conditions(self) -> int: - return len(self.list) - - -@dataclasses.dataclass -class InferenceDataLoaderConfig: - """ - Configuration for inference data. - - This is like the `DataLoaderConfig` class, but with some additional - constraints. During inference, we have only one batch, so the number of - samples directly determines the size of that batch. - - Attributes: - dataset: Configuration to define the dataset. - start_indices: Configuration of the indices for initial conditions - during inference. This can be a list of timestamps, a list of - integer indices, or a slice configuration of the integer indices. - Values following the initial condition will still come from - the full dataset. - num_data_workers: Number of parallel workers to use for data loading. - """ - - dataset: XarrayDataConfig - start_indices: Union[ - InferenceInitialConditionIndices, ExplicitIndices, TimestampList - ] - num_data_workers: int = 0 - - def __post_init__(self): - if self.dataset.subset != Slice(None, None, None): - raise ValueError("Inference data may not be subset.") - - @property - def n_samples(self) -> int: - return self.start_indices.n_initial_conditions - - -@dataclasses.dataclass -class ForcingDataLoaderConfig: - """ - Configuration for the forcing data. - - Attributes: - dataset: Configuration to define the dataset. - num_data_workers: Number of parallel workers to use for data loading. 
- """ - - dataset: XarrayDataConfig - num_data_workers: int = 0 - - def __post_init__(self): - if self.dataset.subset != Slice(None, None, None): - raise ValueError("Inference data may not be subset.") - - def build_inference_config(self, start_indices: ExplicitIndices): - return InferenceDataLoaderConfig( - dataset=self.dataset, - num_data_workers=self.num_data_workers, - start_indices=start_indices, - ) - - -class InferenceDataset(torch.utils.data.Dataset): - def __init__( - self, - config: InferenceDataLoaderConfig, - forward_steps_in_memory: int, - requirements: DataRequirements, - ): - dataset = XarrayDataset(config.dataset, requirements=requirements) - self._dataset = dataset - self._sigma_coordinates = dataset.sigma_coordinates - self._metadata = dataset.metadata - self._area_weights = dataset.area_weights - self._horizontal_coordinates = dataset.horizontal_coordinates - self._timestep = dataset.timestep - self._forward_steps_in_memory = forward_steps_in_memory - self._total_steps = requirements.n_timesteps - 1 - self._is_remote = dataset.is_remote - self.n_samples = config.n_samples # public attribute - if isinstance(config.start_indices, TimestampList): - self._start_indices = config.start_indices.as_indices(dataset.all_times) - else: - self._start_indices = config.start_indices.as_indices() - self._validate_n_forward_steps() - - def __getitem__(self, index) -> BatchData: - dist = Distributed.get_instance() - i_start = index * self._forward_steps_in_memory - sample_tuples = [] - for i_sample in range(self.n_samples): - # check if sample is one this local rank should process - if i_sample % dist.world_size != dist.rank: - continue - i_window_start = i_start + self._start_indices[i_sample] - i_window_end = i_window_start + self._forward_steps_in_memory + 1 - if i_window_end > (self._total_steps + self._start_indices[i_sample]): - i_window_end = self._total_steps + self._start_indices[i_sample] + 1 - window_time_slice = slice(i_window_start, i_window_end) - sample_tuples.append( - self._dataset.get_sample_by_time_slice(window_time_slice) - ) - result = BatchData.from_sample_tuples(sample_tuples) - assert result.times.shape[0] == self.n_samples // dist.world_size - return result - - def __len__(self) -> int: - # The ceil is necessary so if the last batch is smaller - # than the rest the ratio will be rounded up and the last batch - # will be included in the loading - return int(ceil(self._total_steps / self._forward_steps_in_memory)) - - @property - def sigma_coordinates(self) -> SigmaCoordinates: - return self._sigma_coordinates - - @property - def metadata(self) -> xr.Dataset: - return self._metadata - - @property - def area_weights(self) -> xr.DataArray: - return self._area_weights - - @property - def horizontal_coordinates(self) -> HorizontalCoordinates: - return self._horizontal_coordinates - - @property - def timestep(self) -> datetime.timedelta: - return self._timestep - - @property - def is_remote(self) -> bool: - return self._is_remote - - @property - def n_forward_steps(self) -> int: - return self._total_steps - - def _validate_n_forward_steps(self): - max_steps = self._dataset.total_timesteps - self._start_indices[-1] - 1 - if self._total_steps > max_steps: - raise ValueError( - f"The number of forward inference steps ({self._total_steps}) must " - f"be less than or equal to the number of possible steps ({max_steps})" - f"in dataset after the last initial condition's start index." 
-            )
diff --git a/fme/fme/core/data_loading/requirements.py b/fme/fme/core/data_loading/requirements.py
deleted file mode 100644
index 7aecf19..0000000
--- a/fme/fme/core/data_loading/requirements.py
+++ /dev/null
@@ -1,8 +0,0 @@
-import dataclasses
-from typing import List
-
-
-@dataclasses.dataclass
-class DataRequirements:
-    names: List[str]
-    n_timesteps: int
diff --git a/fme/fme/core/dataset/__init__.py b/fme/fme/core/dataset/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/fme/fme/core/dataset/config.py b/fme/fme/core/dataset/config.py
new file mode 100644
index 0000000..a7ecbe6
--- /dev/null
+++ b/fme/fme/core/dataset/config.py
@@ -0,0 +1,280 @@
+import dataclasses
+from datetime import timedelta
+from typing import Literal, Mapping, Optional, Union
+
+import numpy as np
+import pandas as pd
+import torch
+import xarray as xr
+
+from fme.core.typing_ import Slice, TensorDict
+
+
+@dataclasses.dataclass
+class TimeSlice:
+    """
+    Configuration of a slice of times. Step is an integer-valued index step.
+
+    Note: start_time and stop_time may be provided as partial time strings and the
+    stop_time will be included in the slice. See more details in `Xarray docs`_.
+
+    Parameters:
+        start_time: Start time of the slice.
+        stop_time: Stop time of the slice.
+        step: Step of the slice.
+
+    .. _Xarray docs:
+        https://docs.xarray.dev/en/latest/user-guide/weather-climate.html#non-standard-calendars-and-dates-outside-the-nanosecond-precision-range
+    """  # noqa: E501
+
+    start_time: Optional[str] = None
+    stop_time: Optional[str] = None
+    step: Optional[int] = None
+
+    def slice(self, time: xr.CFTimeIndex) -> slice:
+        return time.slice_indexer(self.start_time, self.stop_time, self.step)
+
+
+def _convert_interval_to_int(
+    interval: pd.Timedelta,
+    timestep: timedelta,
+):
+    """Convert interval to integer number of timesteps."""
+    if interval % timestep != timedelta(0):
+        raise ValueError(
+            f"Requested interval length {interval} is not a "
+            f"multiple of the timestep {timestep}."
+        )
+
+    return interval // timestep
+
+
+@dataclasses.dataclass
+class RepeatedInterval:
+    """
+    Configuration for a repeated interval within a block. This configuration
+    is used to generate a boolean mask for a dataset that will return values
+    within the interval and repeat that throughout the dataset.
+
+    Parameters:
+        interval_length: Length of the interval to return values from.
+        start: Start position of the interval within the repeat block.
+        block_length: Total length of the block to be repeated over the length of
+            the dataset, including the interval length.
+
+    Note:
+        The interval_length, start, and block_length can be provided as either
+        all integers or all strings representing timedeltas.
+        If provided as strings, the timestep must be provided when calling
+        `get_boolean_mask`.
+
+    Examples:
+        To return values from the first 3 items of every 6 items, use:
+
+        >>> RepeatedInterval(interval_length=3, block_length=6, start=0)
+
+        To return a day's worth of values starting after 2 days from every 7-day
+        block, use:

+        >>> RepeatedInterval(interval_length="1d", block_length="7d", start="2d")
+    """
+
+    interval_length: Union[int, str]
+    start: Union[int, str]
+    block_length: Union[int, str]
+
+    def __post_init__(self):
+        types = {type(self.interval_length), type(self.block_length), type(self.start)}
+        if len(types) > 1:
+            raise ValueError(
+                "All attributes of RepeatedInterval must be of the "
+                "same type (either all int or all str)."
+ ) + + self._is_time_delta_str = isinstance(self.interval_length, str) + + if self._is_time_delta_str: + self.interval_length = pd.Timedelta(self.interval_length) + self.block_length = pd.Timedelta(self.block_length) + self.start = pd.Timedelta(self.start) + + def get_boolean_mask( + self, length: int, timestep: Optional[timedelta] = None + ) -> np.ndarray: + """ + Return a boolean mask for the repeated interval. + + Args: + length: Length of the dataset. + timestep: Timestep of the dataset. + """ + if self._is_time_delta_str: + if timestep is None: + raise ValueError( + "Timestep must be provided when using time deltas " + "for RepeatedInterval." + ) + + interval_length = _convert_interval_to_int(self.interval_length, timestep) + block_length = _convert_interval_to_int(self.block_length, timestep) + start = _convert_interval_to_int(self.start, timestep) + else: + interval_length = self.interval_length + block_length = self.block_length + start = self.start + + if start + interval_length > block_length: + raise ValueError( + "The interval (with start point) must fit within the repeat block." + ) + + block = np.zeros(block_length, dtype=bool) + block[start : start + interval_length] = True + num_blocks = length // block_length + 1 + mask = np.tile(block, num_blocks)[:length] + return mask + + +@dataclasses.dataclass +class OverwriteConfig: + """Configuration to overwrite field values in XarrayDataset. + + Parameters: + constant: Fill field with constant value. + multiply_scalar: Multiply field by scalar value. + """ + + constant: Mapping[str, float] = dataclasses.field(default_factory=dict) + multiply_scalar: Mapping[str, float] = dataclasses.field(default_factory=dict) + + def __post_init__(self): + key_overlap = set(self.constant.keys()) & set(self.multiply_scalar.keys()) + if key_overlap: + raise ValueError( + "OverwriteConfig cannot have the same variable in both constant " + f"and multiply_scalar: {key_overlap}" + ) + + def apply(self, tensors: TensorDict) -> TensorDict: + for var, fill_value in self.constant.items(): + data = tensors[var] + tensors[var] = torch.ones_like(data) * torch.tensor( + fill_value, dtype=data.dtype, device=data.device + ) + for var, multiplier in self.multiply_scalar.items(): + data = tensors[var] + tensors[var] = data * torch.tensor( + multiplier, dtype=data.dtype, device=data.device + ) + return tensors + + @property + def variables(self): + return set(self.constant.keys()) | set(self.multiply_scalar.keys()) + + +@dataclasses.dataclass +class FillNaNsConfig: + """ + Configuration to fill NaNs with a constant value or others. + + Parameters: + method: Type of fill operation. Currently only 'constant' is supported. + value: Value to fill NaNs with. + """ + + method: Literal["constant"] = "constant" + value: float = 0.0 + + +@dataclasses.dataclass +class XarrayDataConfig: + """ + Parameters: + data_path: Path to the data. + file_pattern: Glob pattern to match files in the data_path. + n_repeats: Number of times to repeat the dataset (in time). It is up + to the user to ensure that the input dataset to repeat results in + data that is reasonably continuous across repetitions. + engine: Backend used in xarray.open_dataset call. + spatial_dimensions: Specifies the spatial dimensions for the grid, default + is lat/lon. + subset: Slice defining a subset of the XarrayDataset to load. This can + either be a `Slice` of integer indices or a `TimeSlice` of timestamps. + infer_timestep: Whether to infer the timestep from the provided data. 
+ This should be set to True (the default) for ACE training. It may + be useful to toggle this to False for applications like downscaling, + which do not depend on the timestep of the data and therefore lack + the additional requirement that the data be ordered and evenly + spaced in time. It must be set to True if n_repeats > 1 in order + to be able to infer the full time coordinate. + dtype: Data type to cast the data to. If None, no casting is done. It is + required that 'torch.{dtype}' is a valid dtype. + overwrite: Optional OverwriteConfig to overwrite loaded field values. If this is + configured for a renamed field, the key should be the final updated name. + renamed_variables: Optional mapping of {old_name: new_name} to rename variables + fill_nans: Optional FillNaNsConfig to fill NaNs with a constant value. + + Examples: + If data is stored in a directory with multiple netCDF files which can be + concatenated along the time dimension, use: + + >>> fme.ace.XarrayDataConfig(data_path="/some/directory", file_pattern="*.nc") # doctest: +IGNORE_OUTPUT + + If data is stored in a single zarr store at ``/some/directory/dataset.zarr``, + use: + + >>> fme.ace.XarrayDataConfig( + ... data_path="/some/directory", + ... file_pattern="dataset.zarr", + ... engine="zarr" + ... ) # doctest: +IGNORE_OUTPUT + """ # noqa: E501 + + data_path: str + file_pattern: str = "*.nc" + n_repeats: int = 1 + engine: Literal["netcdf4", "h5netcdf", "zarr"] = "netcdf4" + spatial_dimensions: Literal["healpix", "latlon"] = "latlon" + subset: Union[Slice, TimeSlice, RepeatedInterval] = dataclasses.field( + default_factory=Slice + ) + infer_timestep: bool = True + dtype: Optional[str] = "float32" + overwrite: OverwriteConfig = dataclasses.field(default_factory=OverwriteConfig) + renamed_variables: Optional[Mapping[str, str]] = None + fill_nans: Optional[FillNaNsConfig] = None + + def _default_file_pattern_check(self): + if self.engine == "zarr" and self.file_pattern == "*.nc": + raise ValueError( + "The file pattern is set to the default NetCDF file pattern *.nc " + "but the engine is specified as 'zarr'. Please set " + "`XarrayDataConfig.file_pattern` to match the zarr filename." + ) + + def __post_init__(self): + if self.n_repeats > 1 and not self.infer_timestep: + raise ValueError( + "infer_timestep must be True if n_repeats is greater than 1" + ) + if self.dtype is None: + self.torch_dtype = None + else: + try: + self.torch_dtype = getattr(torch, self.dtype) + except AttributeError: + raise ValueError(f"Invalid dtype '{self.dtype}'") + if not isinstance(self.torch_dtype, torch.dtype): + raise ValueError(f"Invalid dtype '{self.dtype}'") + + # Raise error if overwrite variables are in the keys of renamed variables + if self.renamed_variables is not None: + overlap = set(self.overwrite.variables) & set(self.renamed_variables.keys()) + if overlap: + raise ValueError( + "Variables in overwrite should not be the original names before " + f"renaming: {overlap}. " + "Please use the final renamed variables in the overwrite config." 
+ ) + self._default_file_pattern_check() diff --git a/fme/fme/core/dataset/data_typing.py b/fme/fme/core/dataset/data_typing.py new file mode 100644 index 0000000..3cd96b4 --- /dev/null +++ b/fme/fme/core/dataset/data_typing.py @@ -0,0 +1,29 @@ +import abc +from collections import namedtuple +from typing import Tuple + +import torch +import xarray as xr + +from fme.core.typing_ import TensorDict + +VariableMetadata = namedtuple("VariableMetadata", ["units", "long_name"]) + + +class Dataset(torch.utils.data.Dataset, abc.ABC): + @abc.abstractmethod + def get_sample_by_time_slice( + self, time_slice: slice + ) -> Tuple[TensorDict, xr.DataArray]: + """ + Returns a sample of data for the given time slice. + + Args: + time_slice: The time slice to return data for. + + Returns: + A tuple whose first item is a mapping from variable + name to tensor of shape [n_time, n_lat, n_lon] and + whose second item is a time coordinate array. + """ + ... diff --git a/fme/fme/core/dataset/getters.py b/fme/fme/core/dataset/getters.py new file mode 100644 index 0000000..e85e9cc --- /dev/null +++ b/fme/fme/core/dataset/getters.py @@ -0,0 +1,46 @@ +import warnings +from typing import List, Optional, Sequence, Tuple + +import torch.utils.data + +from fme.core.dataset.config import XarrayDataConfig +from fme.core.dataset.xarray import DatasetProperties, XarrayDataset, get_xarray_dataset + +from .requirements import DataRequirements + + +def get_datasets( + dataset_configs: Sequence[XarrayDataConfig], + requirements: DataRequirements, + strict: bool = True, +) -> Tuple[List[XarrayDataset], DatasetProperties]: + datasets = [] + properties: Optional[DatasetProperties] = None + for config in dataset_configs: + dataset, new_properties = get_xarray_dataset(config, requirements) + datasets.append(dataset) + if properties is None: + properties = new_properties + elif not strict: + try: + properties.update(new_properties) + except ValueError as e: + warnings.warn( + f"Metadata for each ensemble member are not the same: {e}" + ) + else: + properties.update(new_properties) + if properties is None: + raise ValueError("At least one dataset must be provided.") + + return datasets, properties + + +def get_dataset( + dataset_configs: Sequence[XarrayDataConfig], + requirements: DataRequirements, + strict: bool = True, +) -> Tuple[torch.utils.data.ConcatDataset[XarrayDataset], DatasetProperties]: + datasets, properties = get_datasets(dataset_configs, requirements, strict=strict) + ensemble = torch.utils.data.ConcatDataset(datasets) + return ensemble, properties diff --git a/fme/fme/core/dataset/requirements.py b/fme/fme/core/dataset/requirements.py new file mode 100644 index 0000000..65e2ab6 --- /dev/null +++ b/fme/fme/core/dataset/requirements.py @@ -0,0 +1,16 @@ +import dataclasses +from typing import List + + +@dataclasses.dataclass +class DataRequirements: + """ + The requirements for batches (time windows) of loaded data. + + Parameters: + names: Names of the variables to load. + n_timesteps: Number of timesteps to load in each batch window. 
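+
+    For example, a stepper that uses one input timestep to predict one output
+    timestep per window would require n_timesteps=2.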
+ """ + + names: List[str] + n_timesteps: int diff --git a/fme/fme/core/data_loading/test_utils.py b/fme/fme/core/dataset/test_utils.py similarity index 75% rename from fme/fme/core/data_loading/test_utils.py rename to fme/fme/core/dataset/test_utils.py index e82f2f6..2091bc7 100644 --- a/fme/fme/core/data_loading/test_utils.py +++ b/fme/fme/core/dataset/test_utils.py @@ -1,17 +1,20 @@ import copy import datetime +from typing import List import numpy as np import pytest import torch import xarray as xr -from fme.core.data_loading.data_typing import ( +from fme.core.coordinates import ( + DimSize, HEALPixCoordinates, HorizontalCoordinates, LatLonCoordinates, ) -from fme.core.data_loading.utils import ( + +from .utils import ( _get_indexers, as_broadcasted_tensor, decode_timestep, @@ -31,12 +34,12 @@ def get_sizes( spatial_dims: HorizontalCoordinates = LatLonCoordinates( lon=torch.Tensor(np.arange(6)), lat=torch.Tensor(np.arange(12)), - lat_name=LAT_DIM, - lon_name=LON_DIM, - ) + loaded_lat_name=LAT_DIM, + loaded_lon_name=LON_DIM, + ), ): - spatial_sizes: dict = copy.deepcopy(spatial_dims.default_sizes) - spatial_sizes[TIME_DIM] = 3 + spatial_sizes: List[DimSize] = copy.deepcopy(spatial_dims.loaded_default_sizes) + spatial_sizes.append(DimSize(TIME_DIM, 3)) return spatial_sizes @@ -44,13 +47,13 @@ def create_reference_dataset( spatial_dims: HorizontalCoordinates = LatLonCoordinates( lon=torch.Tensor(np.arange(6)), lat=torch.Tensor(np.arange(12)), - lat_name=LAT_DIM, - lon_name=LON_DIM, - ) + loaded_lat_name=LAT_DIM, + loaded_lon_name=LON_DIM, + ), ): - dims = [TIME_DIM] + spatial_dims.dims - sizes = get_sizes(spatial_dims=spatial_dims) - shape = tuple(sizes[dim] for dim in dims) + dims = [TIME_DIM] + spatial_dims.loaded_dims + dim_sizes = get_sizes(spatial_dims=spatial_dims) + shape = tuple(dim_size.size for dim_size in dim_sizes) data = np.arange(np.prod(shape)).reshape(shape) coords = [np.arange(size) for size in shape] full = xr.DataArray(data, dims=dims, coords=coords, name=FULL_NAME) @@ -71,11 +74,13 @@ def test_infer_horizontal_dimension_names(lon_dim, lat_dim, warns): spatial_dims = LatLonCoordinates( lon=torch.Tensor(np.arange(6)), lat=torch.Tensor(np.arange(12)), - lat_name=lat_dim, - lon_name=lon_dim, + loaded_lat_name=lat_dim, + loaded_lon_name=lon_dim, ) ds = create_reference_dataset(spatial_dims=spatial_dims) expected = [lon_dim, lat_dim] + for dim in expected: + assert dim in ds.dims if warns: with pytest.warns(UserWarning, match="Familiar"): infer_horizontal_dimension_names(ds) @@ -91,18 +96,32 @@ def test_infer_horizontal_dimension_names_healpix(): width=torch.Tensor(np.arange(64)), height=torch.Tensor(np.arange(64)), ) - ds = create_reference_dataset(hpx_coords) - expected = hpx_coords.dims + ds = create_reference_dataset(spatial_dims=hpx_coords) + expected = hpx_coords.loaded_dims result = infer_horizontal_dimension_names(ds) assert result == expected +@pytest.mark.parametrize( + ["coordinate_type", "coord_sizes"], + [ + pytest.param(LatLonCoordinates, {"lat": 90, "lon": 180}), + pytest.param(HEALPixCoordinates, {"face": 12, "height": 64, "width": 64}), + ], +) +def test_horizonal_dimension_sizes(coordinate_type, coord_sizes): + coords = {name: torch.Tensor(np.arange(size)) for name, size in coord_sizes.items()} + horizontal_coords = coordinate_type(**coords) + for name, size in coord_sizes.items(): + assert len(horizontal_coords.coords[name]) == size + + def test_infer_horizontal_dimension_names_error(): spatial_dims = LatLonCoordinates( lon=torch.Tensor(np.arange(6)), 
lat=torch.Tensor(np.arange(12)), - lat_name="foo", - lon_name="bar", + loaded_lat_name="foo", + loaded_lon_name="bar", ) ds = create_reference_dataset(spatial_dims=spatial_dims) ds = ds.isel(time=0) @@ -123,8 +142,10 @@ def test_infer_horizontal_dimension_names_error(): ids=lambda x: f"{x}", ) def test__get_indexers(variable_dims, expected): - sizes = get_sizes() - shape = tuple(sizes[dim] for dim in variable_dims) + dim_sizes = get_sizes() + shape = tuple( + dim_size.size for dim_size in dim_sizes if dim_size.name in variable_dims + ) variable = xr.Variable(variable_dims, np.zeros(shape)) dims = (TIME_DIM, LAT_DIM, LON_DIM) result = _get_indexers(variable, dims) diff --git a/fme/fme/core/data_loading/test__xarray.py b/fme/fme/core/dataset/test_xarray.py similarity index 72% rename from fme/fme/core/data_loading/test__xarray.py rename to fme/fme/core/dataset/test_xarray.py index 014c5bb..14ba0fa 100755 --- a/fme/fme/core/data_loading/test__xarray.py +++ b/fme/fme/core/dataset/test_xarray.py @@ -12,29 +12,30 @@ import torch import xarray as xr -from fme.core.data_loading._xarray import ( +from fme.core.coordinates import LatLonCoordinates +from fme.core.dataset.config import ( + FillNaNsConfig, + OverwriteConfig, + TimeSlice, + XarrayDataConfig, +) +from fme.core.dataset.requirements import DataRequirements +from fme.core.dataset.xarray import ( XarrayDataset, get_cumulative_timesteps, get_file_local_index, get_raw_times, get_timestep, - repeat_and_increment_times, -) -from fme.core.data_loading.config import ( - DataLoaderConfig, - Slice, - TimeSlice, - XarrayDataConfig, -) -from fme.core.data_loading.getters import get_data_loader, get_dataset -from fme.core.data_loading.requirements import DataRequirements -from fme.core.data_loading.utils import ( - as_broadcasted_tensor, - infer_horizontal_dimension_names, + get_xarray_dataset, + repeat_and_increment_time, ) +from fme.core.typing_ import Slice + +from .utils import as_broadcasted_tensor, infer_horizontal_dimension_names SLICE_NONE = slice(None) MOCK_DATA_FREQ = "3h" +MOCK_DATA_START_DATE = "2003-03" @dataclasses.dataclass @@ -75,6 +76,7 @@ def _get_data( file_freq, step_freq, calendar, + with_nans=False, ) -> MockData: """Constructs an xarray dataset and saves to disk in netcdf format.""" obs_times = xr.cftime_range( @@ -109,15 +111,17 @@ def _get_data( last = start_times[i + 1] else: last = obs_times[-1] + obs_delta - times = xr.cftime_range( + time = xr.cftime_range( first, last, freq=step_freq, calendar=calendar, inclusive="left" ) data_vars: Dict[str, Union[float, xr.DataArray]] = {**ak, **bk} for var_name in var_names: - data = np.random.randn(len(times), n_lat, n_lon).astype(np.float32) + data = np.random.randn(len(time), n_lat, n_lon).astype(np.float32) + if with_nans: + data[0, :, 0] = np.nan data_vars[var_name] = xr.DataArray(data, dims=("time", "lat", "lon")) - data_varying_scalar = np.random.randn(len(times)).astype(np.float32) + data_varying_scalar = np.random.randn(len(time)).astype(np.float32) data_vars["varying_scalar_var"] = xr.DataArray( data_varying_scalar, dims=("time",) ) @@ -126,7 +130,7 @@ def _get_data( data_vars["constant_scalar_var"] = constant_scalar_var coords = { - "time": xr.DataArray(times, dims=("time",)), + "time": xr.DataArray(time, dims=("time",)), "lat": xr.DataArray(np.arange(n_lat, dtype=np.float32), dims=("lat",)), "lon": xr.DataArray(np.arange(n_lon, dtype=np.float32), dims=("lon",)), } @@ -151,15 +155,16 @@ def _get_data( return MockData(tmpdir, obs_times, start_times, start_indices, 
variable_names) -def get_mock_monthly_netcdfs(tmp_path_factory, dirname) -> MockData: +def get_mock_monthly_netcdfs(tmp_path_factory, dirname, with_nans=False) -> MockData: return _get_data( tmp_path_factory, dirname, - start="2003-03", + start=MOCK_DATA_START_DATE, end="2003-06", file_freq="MS", step_freq=MOCK_DATA_FREQ, calendar="standard", + with_nans=with_nans, ) @@ -168,6 +173,11 @@ def mock_monthly_netcdfs(tmp_path_factory) -> MockData: return get_mock_monthly_netcdfs(tmp_path_factory, "month") +@pytest.fixture(scope="session") +def mock_monthly_netcdfs_with_nans(tmp_path_factory) -> MockData: + return get_mock_monthly_netcdfs(tmp_path_factory, "month_with_nans", with_nans=True) + + @pytest.fixture(scope="session") def mock_monthly_zarr(tmp_path_factory, mock_monthly_netcdfs) -> MockData: zarr_parent = tmp_path_factory.mktemp("zarr") @@ -250,7 +260,7 @@ def _test_monthly_values( expected_n_samples = len(mock_data.obs_times) - 1 assert len(dataset) == expected_n_samples - arrays, times = dataset[global_idx] + arrays, time = dataset[global_idx] with xr.open_mfdataset( mock_data.tmpdir.glob(file_pattern), engine=engine, @@ -259,7 +269,7 @@ def _test_monthly_values( coords="minimal", ) as ds: target_times = ds["time"][global_idx : global_idx + 2].drop_vars("time") - xr.testing.assert_equal(times, target_times) + xr.testing.assert_equal(time, target_times) lon_dim, lat_dim = infer_horizontal_dimension_names(ds) dims = ("time", str(lat_dim), str(lon_dim)) shape = (2, ds.sizes[lat_dim], ds.sizes[lon_dim]) @@ -309,17 +319,13 @@ def test_XarrayDataset_monthly_n_timesteps(mock_monthly_netcdfs, n_samples): mock_data: MockData = mock_monthly_netcdfs if len(mock_data.var_names.initial_condition_names) != 0: return - config = DataLoaderConfig( - [XarrayDataConfig(data_path=mock_data.tmpdir, subset=Slice(stop=n_samples))], - 1, - 0, - ) + config = XarrayDataConfig(data_path=mock_data.tmpdir, subset=Slice(stop=n_samples)) n_forward_steps = 4 requirements = DataRequirements( names=mock_data.var_names.all_names + ["x"], n_timesteps=n_forward_steps + 1, ) - dataset = get_dataset(config.dataset, requirements) + dataset, _ = get_xarray_dataset(config, requirements) if n_samples is None: assert len(dataset) == len(mock_data.obs_times) - n_forward_steps else: @@ -387,25 +393,36 @@ def test_XarrayDataset_yearly(mock_yearly_netcdfs, global_idx): target_times = ds["time"][global_idx : global_idx + n_steps].drop_vars( "time" ) - data, times = dataset[global_idx] - data = data[var_name] - assert data.shape[0] == n_steps - assert torch.equal(data, target_data) - xr.testing.assert_equal(times, target_times) + data, time = dataset[global_idx] + data_tensor = data[var_name] + assert data_tensor.shape[0] == n_steps + assert torch.equal(data_tensor, target_data) + xr.testing.assert_equal(time, target_times) + + +def test_dataset_dtype_casting(mock_monthly_netcdfs): + mock_data: MockData = mock_monthly_netcdfs + config = XarrayDataConfig(data_path=mock_data.tmpdir, dtype="bfloat16") + requirements = DataRequirements(names=mock_data.var_names.all_names, n_timesteps=2) + dataset = XarrayDataset(config=config, requirements=requirements) + assert isinstance(dataset.horizontal_coordinates, LatLonCoordinates) + assert dataset.horizontal_coordinates.lat.dtype == torch.bfloat16 + assert dataset.horizontal_coordinates.lon.dtype == torch.bfloat16 + assert dataset.vertical_coordinate.ak.dtype == torch.bfloat16 + assert dataset.vertical_coordinate.bk.dtype == torch.bfloat16 + data, _ = dataset[0] + for tensor in data.values(): + 
assert tensor.dtype == torch.bfloat16 def test_time_invariant_variable_is_repeated(mock_monthly_netcdfs): mock_data: MockData = mock_monthly_netcdfs - config = DataLoaderConfig( - [XarrayDataConfig(data_path=mock_data.tmpdir)], - batch_size=1, - num_data_workers=0, - ) + config = XarrayDataConfig(data_path=mock_data.tmpdir) requirements = DataRequirements(names=mock_data.var_names.all_names, n_timesteps=15) - data = get_data_loader(config=config, train=False, requirements=requirements) - batch, _ = data.loader.dataset[0] - assert batch["constant_var"].shape[0] == 15 - assert batch["constant_scalar_var"].shape == (15, 4, 8) + dataset = XarrayDataset(config=config, requirements=requirements) + data = dataset[0][0] + assert data["constant_var"].shape[0] == 15 + assert data["constant_scalar_var"].shape == (15, 4, 8) def _get_repeat_dataset( @@ -479,7 +496,7 @@ def test_time_index(mock_monthly_netcdfs): ) last_sample_init_time = len(mock_monthly_netcdfs.obs_times) - n_timesteps + 1 obs_times = mock_monthly_netcdfs.obs_times[:last_sample_init_time] - assert dataset.sample_start_times.equals(xr.CFTimeIndex(obs_times)) + assert dataset.sample_start_time.equals(xr.CFTimeIndex(obs_times)) @pytest.mark.parametrize("infer_timestep", [True, False]) @@ -559,7 +576,7 @@ def test_repeat_and_increment_times(n_repeats): raw_periods = [periods_a, periods_b] raw_total_periods = sum(raw_periods) - result = repeat_and_increment_times(raw_times, n_repeats, delta) + result = repeat_and_increment_time(raw_times, n_repeats, delta) full_periods = [len(times) for times in result] full_total_periods = sum(full_periods) @@ -573,12 +590,130 @@ def test_repeat_and_increment_times(n_repeats): np.testing.assert_equal(result_concatenated, expected_concatenated) -def test_available_times(mock_monthly_netcdfs): +@pytest.mark.parametrize("n_repeats", [1, 3]) +def test_all_times(mock_monthly_netcdfs, n_repeats): + n_timesteps = 2 # Arbitrary for this test + dataset = _get_repeat_dataset(mock_monthly_netcdfs, n_timesteps, n_repeats) + expected_periods = n_repeats * len(mock_monthly_netcdfs.obs_times) + expected = xr.cftime_range( + MOCK_DATA_START_DATE, periods=expected_periods, freq=MOCK_DATA_FREQ + ) + result = dataset.all_times + assert result.equals(expected) + + +def test_get_sample_by_time_slice_times_n_repeats(mock_monthly_netcdfs: MockData): + n_timesteps = 2 # Arbitrary for this test + n_repeats = 3 + repeated_dataset = _get_repeat_dataset(mock_monthly_netcdfs, n_timesteps, n_repeats) + + # Pick a slice that is outside the range of the unrepeated data + unrepeated_length = len(repeated_dataset.all_times) // n_repeats + time_slice = slice(unrepeated_length, unrepeated_length + 3) + + _, result = repeated_dataset.get_sample_by_time_slice(time_slice) + expected = xr.DataArray( + repeated_dataset.all_times[time_slice].values, dims=["time"] + ) + xr.testing.assert_equal(result, expected) + + +@pytest.mark.parametrize( + "dtype,expected_torch_dtype", [("int16", torch.int16), (None, None)] +) +def test_dataset_config_dtype(dtype, expected_torch_dtype): + config = XarrayDataConfig(data_path="path/to/data", dtype=dtype) + assert config.torch_dtype == expected_torch_dtype + + +def test_dataset_config_dtype_raises(): + with pytest.raises(ValueError): + XarrayDataConfig(data_path="path/to/data", dtype="invalid_dtype") + + +def test_renaming(mock_monthly_netcdfs): + # stepper in/out names should be variables after renaming + stepper_variables = ["foo", "bar_new"] + config = XarrayDataConfig( + data_path=mock_monthly_netcdfs.tmpdir, 
renamed_variables={"bar": "bar_new"} + ) + dataset = XarrayDataset( + config, + DataRequirements(names=stepper_variables, n_timesteps=2), + ) + data, _ = dataset[0] + assert "bar_new" in data + assert "bar" not in data + + +def test_fill_nans(mock_monthly_netcdfs_with_nans): + nan_config = FillNaNsConfig() + config = XarrayDataConfig( + data_path=mock_monthly_netcdfs_with_nans.tmpdir, fill_nans=nan_config + ) + requirements = DataRequirements( + names=mock_monthly_netcdfs_with_nans.var_names.all_names, n_timesteps=2 + ) + dataset = XarrayDataset(config, requirements) + data, _ = dataset[0] + assert torch.all(data["foo"][0, :, 0] == 0) + + +def test_keep_nans(mock_monthly_netcdfs_with_nans): + config_keep_nan = XarrayDataConfig(data_path=mock_monthly_netcdfs_with_nans.tmpdir) + requirements = DataRequirements( + names=mock_monthly_netcdfs_with_nans.var_names.all_names, n_timesteps=2 + ) + dataset = XarrayDataset(config_keep_nan, requirements) + data_with_nan, _ = dataset[0] + assert torch.all(torch.isnan(data_with_nan["foo"][0, :, 0])) + + +def test_overwrite(mock_monthly_netcdfs): + const = -10 + multiple = 3.5 + + overwrite_config = OverwriteConfig( + constant={"foo": const}, + multiply_scalar={"bar": multiple}, + ) + config = XarrayDataConfig(data_path=mock_monthly_netcdfs.tmpdir) + n_timesteps = 2 dataset = XarrayDataset( config, DataRequirements( - names=mock_monthly_netcdfs.var_names.all_names, n_timesteps=10 + names=mock_monthly_netcdfs.var_names.all_names, n_timesteps=n_timesteps ), + )[0][0] + + config_overwrite = XarrayDataConfig( + data_path=mock_monthly_netcdfs.tmpdir, overwrite=overwrite_config ) - assert dataset.all_times.equals(xr.CFTimeIndex(mock_monthly_netcdfs.obs_times)) + n_timesteps = 2 + dataset_overwrite = XarrayDataset( + config_overwrite, + DataRequirements( + names=mock_monthly_netcdfs.var_names.all_names, n_timesteps=n_timesteps + ), + )[0][0] + + for v in ["foo", "bar"]: + assert dataset_overwrite[v].dtype == dataset[v].dtype + assert dataset_overwrite[v].device == dataset[v].device + assert torch.equal( + dataset_overwrite["foo"], torch.ones_like(dataset["foo"]) * const + ) + assert torch.equal(dataset_overwrite["bar"], dataset["bar"] * multiple) + + +def test_overwrite_raises_error_on_original_name(mock_monthly_netcdfs): + overwrite_config = OverwriteConfig( + constant={"foo": 3}, + ) + with pytest.raises(ValueError): + XarrayDataConfig( + data_path=mock_monthly_netcdfs.tmpdir, + overwrite=overwrite_config, + renamed_variables={"foo": "foo_new"}, + ) diff --git a/fme/fme/core/data_loading/utils.py b/fme/fme/core/dataset/utils.py similarity index 65% rename from fme/fme/core/data_loading/utils.py rename to fme/fme/core/dataset/utils.py index 05cb8a5..5146b52 100644 --- a/fme/fme/core/data_loading/utils.py +++ b/fme/fme/core/dataset/utils.py @@ -1,23 +1,15 @@ -import dataclasses import datetime import warnings from typing import Hashable, List, Optional, Sequence, Tuple -import cftime -import numpy as np import torch import xarray as xr -from torch.utils.data import default_collate - -from fme.core.typing_ import TensorMapping - -from .data_typing import HorizontalCoordinates SLICE_NONE = slice(None) -def infer_horizontal_dimension_names(ds: xr.Dataset) -> List[Hashable]: - hdims: List[Hashable] +def infer_horizontal_dimension_names(ds: xr.Dataset) -> List[str]: + hdims: List[str] if "grid_xt" in ds.variables: hdims = ["grid_xt", "grid_yt"] elif "lon" in ds.variables: @@ -112,10 +104,10 @@ def load_series_data( ds: xr.Dataset, names: List[str], time_dim: 
Hashable, - spatial_dims: HorizontalCoordinates, + spatial_dim_names: List[str], ): time_slice = slice(idx, idx + n_steps) - dims = [time_dim] + spatial_dims.dims + dims = [time_dim] + spatial_dim_names shape = [n_steps] + [ds.sizes[spatial_dim] for spatial_dim in dims[1:]] loaded = _load_all_variables(ds, names, time_slice) arrays = {} @@ -125,70 +117,21 @@ def load_series_data( return arrays -def get_horizontal_dimensions(ds: xr.Dataset) -> List[np.ndarray]: +def get_horizontal_dimensions( + ds: xr.Dataset, dtype: Optional[torch.dtype] +) -> List[torch.Tensor]: hdims = infer_horizontal_dimension_names(ds) horizontal_values = [] for dim in hdims: if dim in ds: - horizontal_values.append(np.array(ds[dim].values, dtype=np.float32)) + horizontal_values.append(torch.tensor(ds[dim].values, dtype=dtype)) else: raise ValueError(f"Expected {dim} in dataset: {ds}.") return horizontal_values -def get_times(ds: xr.Dataset, start: int, n_steps: int) -> xr.DataArray: - """ - Get the time coordinate segment from the dataset, check that it's a - cftime.datetime object, and return it is a data array (not a coordinate), - so that it can be concatenated with other samples' times. - """ - time_segment = ds["time"][slice(start, start + n_steps)] - assert isinstance( - time_segment[0].item(), cftime.datetime - ), "time must be cftime.datetime." - if len(time_segment) != n_steps: - raise ValueError( - f"Expected {n_steps} time steps, but got {len(time_segment)} instead." - ) - return time_segment.drop_vars(["time"]) - - -@dataclasses.dataclass -class BatchData: - """A container for the data and time coordinates of a batch. - - Attributes: - data: Data for each variable in each sample, concatenated along samples - to make a batch. To be used directly in training, validation, and - inference. - times: An array of times for each sample in the batch, concatenated along - samples to make a batch. To be used in writing out inference - predictions with time coordinates, not directly in ML. - - """ - - data: TensorMapping - times: xr.DataArray - - @classmethod - def from_sample_tuples( - cls, - samples: Sequence[Tuple[TensorMapping, xr.DataArray]], - sample_dim_name: str = "sample", - ) -> "BatchData": - """ - Collate function for use with PyTorch DataLoader. Needed since samples contain - both tensor mapping and xarray time coordinates, the latter of which we do - not want to convert to tensors. 
- """ - sample_data, sample_times = zip(*samples) - batch_data = default_collate(sample_data) - batch_times = xr.concat(sample_times, dim=sample_dim_name) - return cls(batch_data, batch_times) - - def decode_timestep(microseconds: int) -> datetime.timedelta: return datetime.timedelta(microseconds=microseconds) diff --git a/fme/fme/core/data_loading/_xarray.py b/fme/fme/core/dataset/xarray.py similarity index 60% rename from fme/fme/core/data_loading/_xarray.py rename to fme/fme/core/dataset/xarray.py index 6d6a85e..82ef2a6 100644 --- a/fme/fme/core/data_loading/_xarray.py +++ b/fme/fme/core/dataset/xarray.py @@ -13,24 +13,21 @@ import torch import xarray as xr -import fme -from fme.core import metrics -from fme.core.typing_ import TensorDict - -from .config import Slice, TimeSlice, XarrayDataConfig -from .data_typing import ( - Dataset, +from fme.core.coordinates import ( HEALPixCoordinates, HorizontalCoordinates, + HybridSigmaPressureCoordinate, LatLonCoordinates, - SigmaCoordinates, - VariableMetadata, ) +from fme.core.device import get_device +from fme.core.typing_ import Slice, TensorDict + +from .config import RepeatedInterval, TimeSlice, XarrayDataConfig +from .data_typing import Dataset, VariableMetadata from .requirements import DataRequirements from .utils import ( as_broadcasted_tensor, get_horizontal_dimensions, - get_times, infer_horizontal_dimension_names, load_series_data, ) @@ -48,30 +45,19 @@ ) -def subset_dataset(dataset: Dataset, subset: slice) -> Dataset: - """Returns a subset of the dataset and propagates other properties.""" - indices = range(len(dataset))[subset] - logging.info(f"Subsetting dataset samples according to {subset}.") - subsetted_dataset = torch.utils.data.Subset(dataset, indices) - subsetted_dataset.metadata = dataset.metadata - subsetted_dataset.area_weights = dataset.area_weights - subsetted_dataset.sigma_coordinates = dataset.sigma_coordinates - subsetted_dataset.horizontal_coordinates = dataset.horizontal_coordinates - subsetted_dataset.timestep = dataset.timestep - subsetted_dataset.is_remote = dataset.is_remote - subsetted_dataset.sample_start_times = dataset.sample_start_times[subset] - return subsetted_dataset - - -def get_sigma_coordinates(ds: xr.Dataset) -> SigmaCoordinates: +def _get_vertical_coordinates( + ds: xr.Dataset, dtype: Optional[torch.dtype] +) -> HybridSigmaPressureCoordinate: """ - Get sigma coordinates from a dataset. + Get hybrid sigma-pressure coordinates from a dataset. Assumes that the dataset contains variables named `ak_N` and `bk_N` where `N` is the level number. The returned tensors are sorted by level number. Args: - ds: Dataset to get sigma coordinates from. + ds: Dataset to get vertical coordinates from. + dtype: Data type of the returned tensors. If None, the dtype is not + changed from the original in ds. """ ak_mapping = { int(v[3:]): torch.as_tensor(ds[v].values) @@ -88,20 +74,14 @@ def get_sigma_coordinates(ds: xr.Dataset) -> SigmaCoordinates: if len(ak_list) == 0 or len(bk_list) == 0: logger.warning("Dataset does not contain ak and bk coordinates.") - return SigmaCoordinates( - ak=torch.tensor([], device=fme.get_device()), - bk=torch.tensor([], device=fme.get_device()), + return HybridSigmaPressureCoordinate( + ak=torch.tensor([]), + bk=torch.tensor([]), ) - if len(ak_list) != len(bk_list): - raise ValueError( - "Expected same number of ak and bk coordinates, " - f"got {len(ak_list)} and {len(bk_list)}." 
- ) - - return SigmaCoordinates( - ak=torch.as_tensor(ak_list, device=fme.get_device(), dtype=torch.float), - bk=torch.as_tensor(bk_list, device=fme.get_device(), dtype=torch.float), + return HybridSigmaPressureCoordinate( + ak=torch.as_tensor(ak_list, dtype=dtype), + bk=torch.as_tensor(bk_list, dtype=dtype), ) @@ -113,26 +93,26 @@ def get_raw_times(paths: List[str], engine: str) -> List[np.ndarray]: return times -def repeat_and_increment_times( +def repeat_and_increment_time( raw_times: List[np.ndarray], n_repeats: int, timestep: datetime.timedelta ) -> List[np.ndarray]: """Repeats and increments a collection of arrays of evenly spaced times.""" n_timesteps = sum(len(times) for times in raw_times) timespan = timestep * n_timesteps - repeated_and_incremented_times = [] + repeated_and_incremented_time = [] for repeats in range(n_repeats): increment = repeats * timespan - for times in raw_times: - incremented_times = times + increment - repeated_and_incremented_times.append(incremented_times) - return repeated_and_incremented_times + for time in raw_times: + incremented_time = time + increment + repeated_and_incremented_time.append(incremented_time) + return repeated_and_incremented_time -def get_cumulative_timesteps(times: List[np.ndarray]) -> np.ndarray: - """Returns a list of cumulative timesteps for each item in times.""" +def get_cumulative_timesteps(time: List[np.ndarray]) -> np.ndarray: + """Returns a list of cumulative timesteps for each item in a time coordinate.""" num_timesteps_per_file = [0] - for time_coord in times: + for time_coord in time: num_timesteps_per_file.append(len(time_coord)) return np.array(num_timesteps_per_file).cumsum() @@ -261,18 +241,65 @@ def _open_file_fh_cached(path, **kwargs): protocol = _get_protocol(path) if protocol: # add an LRU cache for remote zarrs - fn = _open_xr_dataset_lru - else: - # netcdf4 and h5engine have a filehandle LRU cache in xarray - # https://github.com/pydata/xarray/blob/cd3ab8d5580eeb3639d38e1e884d2d9838ef6aa1/xarray/backends/file_manager.py#L54 # noqa: E501 - fn = _open_xr_dataset - - return fn( + return _open_xr_dataset_lru( + path, + **kwargs, + ) + # netcdf4 and h5engine have a filehandle LRU cache in xarray + # https://github.com/pydata/xarray/blob/cd3ab8d5580eeb3639d38e1e884d2d9838ef6aa1/xarray/backends/file_manager.py#L54 # noqa: E501 + return _open_xr_dataset( path, **kwargs, ) +class DatasetProperties: + def __init__( + self, + variable_metadata: Mapping[str, VariableMetadata], + vertical_coordinate: HybridSigmaPressureCoordinate, + horizontal_coordinates: HorizontalCoordinates, + timestep: datetime.timedelta, + is_remote: bool, + ): + self.variable_metadata = variable_metadata + self.vertical_coordinate = vertical_coordinate + self.horizontal_coordinates = horizontal_coordinates + self.timestep = timestep + self.is_remote = is_remote + + def to_device(self) -> "DatasetProperties": + device = get_device() + return DatasetProperties( + self.variable_metadata, + self.vertical_coordinate.to(device), + self.horizontal_coordinates.to(device), + self.timestep, + self.is_remote, + ) + + def update(self, other: "DatasetProperties"): + self.is_remote = self.is_remote or other.is_remote + if self.timestep != other.timestep: + raise ValueError("Inconsistent timesteps between datasets") + if self.variable_metadata != other.variable_metadata: + raise ValueError("Inconsistent metadata between datasets") + if self.vertical_coordinate != other.vertical_coordinate: + raise ValueError("Inconsistent vertical coordinates between datasets") + 
if self.horizontal_coordinates != other.horizontal_coordinates: + raise ValueError("Inconsistent horizontal coordinates between datasets") + + +def get_xarray_dataset( + config: XarrayDataConfig, requirements: DataRequirements +) -> Tuple["Dataset", DatasetProperties]: + dataset = XarrayDataset(config, requirements) + properties = dataset.properties + index_slice = as_index_selection(config.subset, dataset) + dataset = dataset.subset(index_slice) + return dataset, properties + + class XarrayDataset(Dataset): """Load data from a directory of files matching a pattern using xarray. The number of contiguous timesteps to load for each sample is specified by @@ -280,7 +307,8 @@ class XarrayDataset(Dataset): For example, if the file(s) have the time coordinate (t0, t1, t2, t3, t4) and requirements.n_timesteps=3, then this dataset will - provide three samples: (t0, t1, t2), (t1, t2, t3), and (t2, t3, t4).""" + provide three samples: (t0, t1, t2), (t1, t2, t3), and (t2, t3, t4). + """ def __init__( self, @@ -288,17 +316,20 @@ def __init__( requirements: DataRequirements, ): self._horizontal_coordinates: HorizontalCoordinates - self.names = requirements.names + self.renamed_variables = config.renamed_variables or {} + self._names = self._get_names_to_load(requirements.names) self.path = config.data_path self.file_pattern = config.file_pattern - self.engine = "netcdf4" if config.engine is None else config.engine - self._default_file_pattern_check() + self.engine = config.engine + self.dtype = config.torch_dtype + self.spatial_dimensions = config.spatial_dimensions + self.fill_nans = config.fill_nans fs = _get_fs(self.path) glob_paths = sorted(fs.glob(os.path.join(self.path, config.file_pattern))) self._raw_paths = _preserve_protocol(self.path, glob_paths) if len(self._raw_paths) == 0: raise ValueError( - f"No files found matching '{self.path}/{config.file_pattern}'." + f"No files found matching '{self.path}/{self.file_pattern}'." ) self.full_paths = self._raw_paths * config.n_repeats self.n_steps = requirements.n_timesteps # one input, n_steps - 1 outputs @@ -310,15 +341,34 @@ def __init__( ) ( self._horizontal_coordinates, - self._area_weights, self._static_derived_data, - ) = self.configure_horizontal_coordinates(config, first_dataset) + ) = self.configure_horizontal_coordinates(first_dataset) ( - self.time_dependent_names, - self.time_invariant_names, - self.static_derived_names, + self._time_dependent_names, + self._time_invariant_names, + self._static_derived_names, ) = self._group_variable_names_by_time_type() - self._sigma_coordinates = get_sigma_coordinates(first_dataset) + self._vertical_coordinates = _get_vertical_coordinates( + first_dataset, self.dtype + ) + self.overwrite = config.overwrite + + @property + def properties(self) -> DatasetProperties: + return DatasetProperties( + self._variable_metadata, + self._vertical_coordinates, + self._horizontal_coordinates, + self.timestep, + self.is_remote, + ) + + def _get_names_to_load(self, names: List[str]) -> List[str]: + # requirements.names from stepper config refer to the final set of + # variables after any renaming occurs. This returns the set of names + # to load from data before renaming. 
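+        # For example, with renamed_variables={"bar": "bar_new"}, a request
+        # for "bar_new" is satisfied by loading "bar" from the files on disk.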
+ inverted_renaming = {v: k for k, v in self.renamed_variables.items()} + return [inverted_renaming.get(n, n) for n in names] @property def horizontal_coordinates(self) -> HorizontalCoordinates: @@ -333,20 +383,12 @@ def is_remote(self) -> bool: @property def all_times(self) -> xr.CFTimeIndex: - """Time index of all available times in the data""" + """Time index of all available times in the data.""" return self._all_times - def _default_file_pattern_check(self): - if self.engine == "zarr" and self.file_pattern == "*.nc": - raise ValueError( - "The file pattern is set to the default NetCDF file pattern *.nc " - "but the engine is specified as 'zarr'. Please set " - "`XarrayDataConfig.file_pattern` to match the zarr filename." - ) - - def _get_metadata(self, ds): + def _get_variable_metadata(self, ds): result = {} - for name in self.names: + for name in self._names: if name in StaticDerivedData.names: result[name] = StaticDerivedData.metadata[name] elif hasattr(ds[name], "units") and hasattr(ds[name], "long_name"): @@ -354,7 +396,7 @@ def _get_metadata(self, ds): units=ds[name].units, long_name=ds[name].long_name, ) - self._metadata = result + self._variable_metadata = result def _get_files_stats(self, n_repeats: int, infer_timestep: bool): logging.info(f"Opening data at {os.path.join(self.path, self.file_pattern)}") @@ -363,44 +405,32 @@ def _get_files_stats(self, n_repeats: int, infer_timestep: bool): self._timestep: Optional[datetime.timedelta] if infer_timestep: self._timestep = get_timestep(np.concatenate(raw_times)) - time_coords = repeat_and_increment_times( - raw_times, n_repeats, self.timestep - ) + time_coord = repeat_and_increment_time(raw_times, n_repeats, self.timestep) else: self._timestep = None - time_coords = raw_times + time_coord = raw_times - cum_num_timesteps = get_cumulative_timesteps(time_coords) + cum_num_timesteps = get_cumulative_timesteps(time_coord) self.start_indices = cum_num_timesteps[:-1] self.total_timesteps = cum_num_timesteps[-1] self._n_initial_conditions = self.total_timesteps - self.n_steps + 1 - self._sample_start_times = xr.CFTimeIndex( - np.concatenate(time_coords)[: self._n_initial_conditions] + self._sample_start_time = xr.CFTimeIndex( + np.concatenate(time_coord)[: self._n_initial_conditions] ) - self._all_times = xr.CFTimeIndex(np.concatenate(raw_times)) + self._all_times = xr.CFTimeIndex(np.concatenate(time_coord)) - del cum_num_timesteps, time_coords + del cum_num_timesteps, time_coord ds = self._open_file(0) - self._get_metadata(ds) + self._get_variable_metadata(ds) - for i in range(len(self.names)): - if self.names[i] in ds.variables: - img_shape = ds[self.names[i]].shape[-2:] - break - else: - raise ValueError( - f"None of the requested variables {self.names} are present " - f"in the dataset." - ) logging.info(f"Found {self._n_initial_conditions} samples.") - logging.info(f"Image shape is {img_shape[0]} x {img_shape[1]}.") - logging.info(f"Following variables are available: {list(ds.variables)}.") def _group_variable_names_by_time_type(self) -> VariableNames: """Returns lists of time-dependent variable names, time-independent variable names, and variables which are only present as an initial - condition.""" + condition. + """ ( time_dependent_names, time_invariant_names, @@ -410,72 +440,81 @@ def _group_variable_names_by_time_type(self) -> VariableNames: # fields a time dimension. We assume that all fields are present in the # netcdf file corresponding to the first chunk of time. 
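+            # any requested variable that is absent from the dataset raises a
+            # ValueError below, rather than failing later with a KeyError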
with _open_xr_dataset(self.full_paths[0], engine=self.engine) as ds: - for name in self.names: + for name in self._names: if name in StaticDerivedData.names: static_derived_names.append(name) else: - dims = ds[name].dims - if "time" in dims: - time_dependent_names.append(name) + try: + da = ds[name] + except KeyError: + raise ValueError( + f"Required variable not found in dataset: {name}." + ) else: - time_invariant_names.append(name) + dims = da.dims + if "time" in dims: + time_dependent_names.append(name) + else: + time_invariant_names.append(name) + logging.info( + f"The required variables have been found in the dataset: {self._names}." + ) + return VariableNames( time_dependent_names, time_invariant_names, static_derived_names, ) - def configure_horizontal_coordinates(self, config, first_dataset): + def configure_horizontal_coordinates( + self, first_dataset + ) -> Tuple[HorizontalCoordinates, StaticDerivedData]: horizontal_coordinates: HorizontalCoordinates - area_weights: torch.Tensor static_derived_data: StaticDerivedData - dims = get_horizontal_dimensions(first_dataset) + dims = get_horizontal_dimensions(first_dataset, self.dtype) - if config.spatial_dimensions == "latlon": + if self.spatial_dimensions == "latlon": lons = dims[0] lats = dims[1] names = infer_horizontal_dimension_names(first_dataset) lon_name = names[0] lat_name = names[1] horizontal_coordinates = LatLonCoordinates( - lon=torch.as_tensor(lons, device=fme.get_device()), - lat=torch.as_tensor(lats, device=fme.get_device()), - lat_name=lat_name, - lon_name=lon_name, + lon=lons, + lat=lats, + loaded_lat_name=lat_name, + loaded_lon_name=lon_name, ) - area_weights = metrics.spherical_area_weights(lats, len(lons)) static_derived_data = StaticDerivedData(horizontal_coordinates) - elif config.spatial_dimensions == "healpix": + elif self.spatial_dimensions == "healpix": face = dims[0] height = dims[1] width = dims[2] horizontal_coordinates = HEALPixCoordinates( - face=torch.as_tensor(face, device=fme.get_device()), - height=torch.as_tensor(height, device=fme.get_device()), - width=torch.as_tensor(width, device=fme.get_device()), + face=face, + height=height, + width=width, ) - # Area weights should be all 1's of shape (face, height, width), - # since area is uniform. 
- area_weights = torch.ones((len(face), len(height), len(width))) static_derived_data = StaticDerivedData(horizontal_coordinates) else: raise ValueError( - f"unexpected config.spatial_dimensions {config.spatial_dimensions}," + f"unexpected config.spatial_dimensions {self.spatial_dimensions}," " should be one of 'latlon' or 'healpix'" ) - return horizontal_coordinates, area_weights, static_derived_data + coords_sizes = { + coord_name: len(coord) + for coord_name, coord in horizontal_coordinates.coords.items() + } + logging.info(f"Horizontal coordinate sizes are {coords_sizes}.") + return horizontal_coordinates, static_derived_data @property - def area_weights(self) -> torch.Tensor: - return self._area_weights + def variable_metadata(self) -> Mapping[str, VariableMetadata]: + return self._variable_metadata @property - def metadata(self) -> Mapping[str, VariableMetadata]: - return self._metadata - - @property - def sigma_coordinates(self) -> SigmaCoordinates: - return self._sigma_coordinates + def vertical_coordinate(self) -> HybridSigmaPressureCoordinate: + return self._vertical_coordinates @property def timestep(self) -> datetime.timedelta: @@ -496,9 +535,9 @@ def _open_file(self, idx): return _open_file_fh_cached(self.full_paths[idx], engine=self.engine) @property - def sample_start_times(self) -> xr.CFTimeIndex: + def sample_start_time(self) -> xr.CFTimeIndex: """Return cftime index corresponding to start time of each sample.""" - return self._sample_start_times + return self._sample_start_time def __getitem__(self, idx: int) -> Tuple[TensorDict, xr.DataArray]: """Return a sample of data spanning the timesteps [idx, idx + self.n_steps). @@ -525,26 +564,26 @@ def get_sample_by_time_slice( # get the sequence of observations arrays: Dict[str, List[torch.Tensor]] = {} - times_segments: List[xr.DataArray] = [] idxs = range(input_file_idx, output_file_idx + 1) total_steps = 0 for i, file_idx in enumerate(idxs): ds = self._open_file(file_idx) + if self.fill_nans is not None: + ds = ds.fillna(self.fill_nans.value) start = input_local_idx if i == 0 else 0 stop = output_local_idx if i == len(idxs) - 1 else len(ds["time"]) - 1 n_steps = stop - start + 1 total_steps += n_steps tensor_dict = load_series_data( - start, - n_steps, - ds, - self.time_dependent_names, - "time", - self._horizontal_coordinates, + idx=start, + n_steps=n_steps, + ds=ds, + names=self._time_dependent_names, + time_dim="time", + spatial_dim_names=self._horizontal_coordinates.loaded_dims, ) - for n in self.time_dependent_names: + for n in self._time_dependent_names: arrays.setdefault(n, []).append(tensor_dict[n]) - times_segments.append(get_times(ds, start, n_steps)) ds.close() del ds @@ -552,49 +591,77 @@ def get_sample_by_time_slice( for n, tensor_list in arrays.items(): tensors[n] = torch.cat(tensor_list) del arrays - times: xr.DataArray = xr.concat(times_segments, dim="time") # load time-invariant variables from first dataset - if len(self.time_invariant_names) > 0: + if len(self._time_invariant_names) > 0: ds = self._open_file(idxs[0]) - dims = ["time"] + self._horizontal_coordinates.dims + dims = ["time"] + self._horizontal_coordinates.loaded_dims shape = [total_steps] + [ds.sizes[dim] for dim in dims[1:]] - for name in self.time_invariant_names: + for name in self._time_invariant_names: variable = ds[name].variable tensors[name] = as_broadcasted_tensor(variable, dims, shape) ds.close() del ds # load static derived variables - for name in self.static_derived_names: + for name in self._static_derived_names: tensor = 
self._static_derived_data[name] tensors[name] = tensor.repeat((total_steps, 1, 1)) - return tensors, times + # cast to desired dtype + tensors = {k: v.to(dtype=self.dtype) for k, v in tensors.items()} + + # apply renaming + for original_name, new_name in self.renamed_variables.items(): + tensors[new_name] = tensors.pop(original_name) + # Apply field overwrites + tensors = self.overwrite.apply(tensors) -def as_index_slice(subset: Union[Slice, TimeSlice], dataset: XarrayDataset) -> slice: + # Create a DataArray of times to return corresponding to the slice that + # is valid even when n_repeats > 1. + time = xr.DataArray(self.all_times[time_slice].values, dims=["time"]) + + return tensors, time + + def subset(self, subset: Union[slice, torch.Tensor]) -> Dataset: + """Returns a subset of the dataset and propagates other properties.""" + indices = range(len(self))[subset] + logging.info(f"Subsetting dataset samples according to {subset}.") + subsetted_dataset = torch.utils.data.Subset(self, indices) + return subsetted_dataset + + +def as_index_selection( + subset: Union[Slice, TimeSlice, RepeatedInterval], dataset: XarrayDataset +) -> Union[slice, np.ndarray]: """Converts a subset defined either as a Slice or TimeSlice into an index slice - based on time coordinate in provided dataset.""" + based on time coordinate in provided dataset. + """ if isinstance(subset, Slice): - index_slice = subset.slice + index_selection = subset.slice elif isinstance(subset, TimeSlice): - index_slice = subset.slice(dataset.sample_start_times) + index_selection = subset.slice(dataset.sample_start_time) + elif isinstance(subset, RepeatedInterval): + try: + index_selection = subset.get_boolean_mask(len(dataset), dataset.timestep) + except ValueError as e: + raise ValueError(f"Error when applying RepeatedInterval to dataset: {e}") else: raise TypeError(f"subset must be Slice or TimeSlice, got {type(subset)}") - return index_slice + return index_selection -def get_timestep(times: np.ndarray) -> datetime.timedelta: - """Computes the timestep of an array of times. +def get_timestep(time: np.ndarray) -> datetime.timedelta: + """Computes the timestep of an array of a time coordinate array. Raises an error if the times are not separated by a positive constant interval, or if the array has one or fewer times. """ - assert len(times.shape) == 1, "times must be a 1D array" + assert len(time.shape) == 1, "times must be a 1D array" - if len(times) > 1: - timesteps = np.diff(times) + if len(time) > 1: + timesteps = np.diff(time) timestep = timesteps[0] if not (timestep > datetime.timedelta(days=0)): diff --git a/fme/fme/core/device.py b/fme/fme/core/device.py index 0c24044..f0dfab3 100644 --- a/fme/fme/core/device.py +++ b/fme/fme/core/device.py @@ -2,6 +2,8 @@ import torch +from .typing_ import TensorDict, TensorMapping + def using_gpu() -> bool: return get_device().type == "cuda" @@ -9,12 +11,18 @@ def using_gpu() -> bool: def get_device() -> torch.device: """If CUDA is available, return a CUDA device. Otherwise, return a CPU device - unless FME_USE_MPS is set, in which case return an MPS device if available.""" + unless FME_USE_MPS is set, in which case return an MPS device if available. 
+ """ if torch.cuda.is_available(): return torch.device("cuda", torch.cuda.current_device()) else: mps_available = torch.backends.mps.is_available() if mps_available and os.environ.get("FME_USE_MPS", "0") == "1": - return torch.device("mps") + return torch.device("mps", 0) else: return torch.device("cpu") + + +def move_tensordict_to_device(data: TensorMapping) -> TensorDict: + device = get_device() + return {name: value.to(device) for name, value in data.items()} diff --git a/fme/fme/core/dicts.py b/fme/fme/core/dicts.py index 1c2ffda..1b07b78 100644 --- a/fme/fme/core/dicts.py +++ b/fme/fme/core/dicts.py @@ -5,9 +5,8 @@ def to_flat_dict(d: Mapping[str, Any]) -> Dict[str, Any]: """ Converts any nested dictionaries to a flat version with the nested keys joined with a '.', e.g., {a: {b: 1}} -> - {a.b: 1} + {a.b: 1}. """ - new_flat = {} for k, v in d.items(): if isinstance(v, dict): @@ -23,9 +22,8 @@ def to_flat_dict(d: Mapping[str, Any]) -> Dict[str, Any]: def to_nested_dict(d: Mapping[str, Any]) -> Dict[str, Any]: """ Converts a flat dictionary with '.' joined keys back into - a nested dictionary, e.g., {a.b: 1} -> {a: {b: 1}} + a nested dictionary, e.g., {a.b: 1} -> {a: {b: 1}}. """ - new_config: Dict[str, Any] = {} for k, v in d.items(): diff --git a/fme/fme/core/distributed.py b/fme/fme/core/distributed.py index eb12cb2..5df8151 100644 --- a/fme/fme/core/distributed.py +++ b/fme/fme/core/distributed.py @@ -2,6 +2,7 @@ from typing import Callable, List, Optional, Union import torch.distributed +from torch.nn import SyncBatchNorm from torch.nn.functional import pad from torch.nn.parallel import DistributedDataParallel @@ -40,7 +41,7 @@ class Distributed: variables without having to pass them around, and lets us put the initialization for this global state in the same place as the routines that use it. - Attributes: + Parameters: world_size: The number of processes in the distributed training job. rank: The global rank of the current process. local_rank: The node-local rank of the current process. @@ -179,7 +180,6 @@ def gather_irregular( A list of tensors of consistent shape, where the i-th element is the tensor from the i-th process. """ - return gather_irregular( tensor, self.reduce_max, @@ -212,7 +212,7 @@ def wrap_module(self, module: torch.nn.Module) -> torch.nn.Module: device_ids = None output_device = None return DistributedDataParallel( - module, + SyncBatchNorm.convert_sync_batchnorm(module), device_ids=device_ids, output_device=output_device, ) @@ -244,7 +244,6 @@ def gather_irregular( Returns: A list of tensors, where the i-th element is the tensor from the i-th process. """ - output_tensor_size = [] tensor_size = list(tensor.size()) for dim_len in tensor_size: @@ -281,7 +280,7 @@ def pad_tensor_at_end( fill_value: Union[float, int] = 0.0, ): """Pad tensor by specified amount at end of each dimension. - Note that `pad` format is in reverse dimension order + Note that `pad` format is in reverse dimension order. Args: tensor: The tensor to pad @@ -311,7 +310,7 @@ def pad_tensor_at_end( def unpad_tensor_at_end( tensor: torch.Tensor, dimension_difference: torch.Tensor ) -> torch.Tensor: - """Remove padding from tensor + """Remove padding from tensor. Args: tensor: The tensor to remove padding from diff --git a/fme/fme/core/ema.py b/fme/fme/core/ema.py index dd1e336..8966b72 100644 --- a/fme/fme/core/ema.py +++ b/fme/fme/core/ema.py @@ -1,5 +1,5 @@ """ -Exponential Moving Average (EMA) module +Exponential Moving Average (EMA) module. 
Copied from https://github.com/CompVis/latent-diffusion/blob/main/ldm/modules/ema.py and modified. @@ -37,8 +37,7 @@ class HasNamedParameters(Protocol): def named_parameters( self, recurse: bool = True - ) -> Iterator[Tuple[str, nn.Parameter]]: - ... + ) -> Iterator[Tuple[str, nn.Parameter]]: ... @dataclasses.dataclass @@ -46,7 +45,7 @@ class EMAConfig: """ Configuration for exponential moving average of model weights. - Attributes: + Parameters: decay: decay rate for the moving average """ @@ -198,7 +197,7 @@ def from_state(cls, state, model) -> "EMATracker": Returns: The EMA tracker. """ - ema = cls(model, state["decay"], state["faster_decay_at_start"]) + ema = cls(model, float(state["decay"]), state["faster_decay_at_start"]) ema.num_updates = state["num_updates"] ema._module_name_to_ema_name = state["module_name_to_ema_name"] return ema diff --git a/fme/fme/core/generics/aggregator.py b/fme/fme/core/generics/aggregator.py new file mode 100644 index 0000000..146eefd --- /dev/null +++ b/fme/fme/core/generics/aggregator.py @@ -0,0 +1,53 @@ +import abc +from typing import Any, Dict, Generic, List, TypeVar + +PS = TypeVar("PS", contravariant=True) # prognostic state +T = TypeVar("T", contravariant=True) + + +class AggregatorABC(abc.ABC, Generic[T]): + @abc.abstractmethod + def record_batch(self, batch: T) -> None: + pass + + @abc.abstractmethod + def get_logs(self, label: str) -> Dict[str, float]: + pass + + +InferenceLog = Dict[str, Any] +InferenceLogs = List[InferenceLog] + + +class InferenceAggregatorABC(abc.ABC, Generic[PS, T]): + @abc.abstractmethod + def record_batch( + self, + data: T, + ) -> InferenceLogs: + """ + Record a batch of data. + + Args: + data: Batch of data. + + Returns: + Logs for the batch. + """ + pass + + @abc.abstractmethod + def record_initial_condition( + self, + initial_condition: PS, + ) -> InferenceLogs: + """ + Record the initial condition. + + May only be recorded once, before any calls to record_batch. + """ + pass + + @abc.abstractmethod + def get_summary_logs(self) -> InferenceLog: + pass diff --git a/fme/fme/core/generics/data.py b/fme/fme/core/generics/data.py new file mode 100644 index 0000000..da62d33 --- /dev/null +++ b/fme/fme/core/generics/data.py @@ -0,0 +1,68 @@ +import abc +from typing import Generic, Iterable, Protocol, Sized, TypeVar + +T = TypeVar("T", covariant=True) + + +class DataLoader(Protocol, Generic[T], Sized, Iterable[T]): + pass + + +PS = TypeVar("PS") # prognostic state +FD = TypeVar("FD", covariant=True) # forcing data + + +class InferenceDataABC(abc.ABC, Generic[PS, FD]): + @property + @abc.abstractmethod + def initial_condition(self) -> PS: ... + + @property + @abc.abstractmethod + def loader(self) -> DataLoader[FD]: ... + + +class SimpleInferenceData(InferenceDataABC[PS, FD]): + def __init__( + self, + initial_condition: PS, + loader: DataLoader[FD], + ): + self._initial_condition = initial_condition + self._loader = loader + + @property + def initial_condition(self) -> PS: + return self._initial_condition + + @property + def loader(self) -> DataLoader[FD]: + return self._loader + + +class GriddedDataABC(abc.ABC, Generic[T]): + @property + @abc.abstractmethod + def loader(self) -> DataLoader[T]: ... + + @property + @abc.abstractmethod + def n_samples(self) -> int: ... + + @property + @abc.abstractmethod + def n_batches(self) -> int: ... + + @property + @abc.abstractmethod + def batch_size(self) -> int: ... + + @abc.abstractmethod + def set_epoch(self, epoch: int): ... 
+ + @abc.abstractmethod + def log_info(self, name: str): + """ + Report information about the data using logging.info. + """ + ... diff --git a/fme/fme/core/generics/inference.py b/fme/fme/core/generics/inference.py new file mode 100644 index 0000000..a8382d3 --- /dev/null +++ b/fme/fme/core/generics/inference.py @@ -0,0 +1,139 @@ +import logging +from typing import Callable, Generic, Iterator, Optional, Protocol, Tuple, TypeVar + +import torch + +from fme.core.generics.aggregator import InferenceAggregatorABC, InferenceLogs +from fme.core.generics.data import InferenceDataABC +from fme.core.generics.writer import NullDataWriter, WriterABC +from fme.core.timing import GlobalTimer +from fme.core.wandb import WandB + +PS = TypeVar("PS") # prognostic state +FD = TypeVar("FD", contravariant=True) # forcing data +SD = TypeVar("SD", covariant=True) # stepped data + + +class PredictFunction(Protocol, Generic[PS, FD, SD]): + def __call__( + self, + initial_condition: PS, + forcing: FD, + compute_derived_variables: bool = False, + ) -> Tuple[SD, PS]: ... + + +class Looper(Generic[PS, FD, SD]): + """ + Class for stepping a model forward arbitarily many times. + """ + + def __init__( + self, + predict: PredictFunction[PS, FD, SD], + data: InferenceDataABC[PS, FD], + ): + """ + Args: + predict: The prediction function to use. + data: The data to use. + """ + self._predict = predict + self._prognostic_state = data.initial_condition + self._len = len(data.loader) + self._loader = iter(data.loader) + + def __iter__(self) -> Iterator[SD]: + return self + + def __len__(self) -> int: + return self._len + + def __next__(self) -> SD: + """Return predictions for the time period corresponding to the next batch + of forcing data. Also returns the forcing data. + """ + timer = GlobalTimer.get_instance() + with timer.context("data_loading"): + try: + forcing_data = next(self._loader) + except StopIteration: + raise StopIteration + output_data, self._prognostic_state = self._predict( + self._prognostic_state, + forcing=forcing_data, + compute_derived_variables=True, + ) + return output_data + + def get_prognostic_state(self) -> PS: + return self._prognostic_state + + +def get_record_to_wandb(label: str = "") -> Callable[[InferenceLogs], None]: + wandb = WandB.get_instance() + step = 0 + + def record_logs(logs: InferenceLogs): + nonlocal step + for j, log in enumerate(logs): + if len(log) > 0: + if label != "": + log = {f"{label}/{k}": v for k, v in log.items()} + wandb.log(log, step=step + j) + step += len(logs) + + return record_logs + + +def run_inference( + predict: PredictFunction[PS, FD, SD], + data: InferenceDataABC[PS, FD], + aggregator: InferenceAggregatorABC[PS, SD], + writer: Optional[WriterABC[PS, SD]] = None, + record_logs: Optional[Callable[[InferenceLogs], None]] = None, +): + """Run extended inference loop given initial condition and forcing data. + + Args: + predict: The prediction function to use. + data: Provides an initial condition and appropriately aligned windows of + forcing data. + aggregator: Aggregator for collecting and reducing metrics. + writer: Data writer for saving the inference results to disk. + record_logs: Function for recording logs. By default, logs are recorded to + wandb. 
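+
+    Example:
+        Schematic call, where ``predict_fn``, ``inference_data``, ``aggregator``
+        and ``writer`` are placeholders for concrete implementations of the
+        interfaces accepted above::
+
+            run_inference(
+                predict=predict_fn,
+                data=inference_data,
+                aggregator=aggregator,
+                writer=writer,
+            )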
+ """ + if record_logs is None: + record_logs = get_record_to_wandb(label="inference") + if writer is None: + writer = NullDataWriter() + timer = GlobalTimer.get_instance() + with torch.no_grad(): + looper = Looper(predict=predict, data=data) + with timer.context("aggregator"): + logs = aggregator.record_initial_condition( + initial_condition=data.initial_condition, + ) + with timer.context("wandb_logging"): + record_logs(logs) + with timer.context("data_writer"): + writer.write(data.initial_condition, "initial_condition.nc") + n_windows = len(looper) + for i, batch in enumerate(looper): + logging.info( + f"Inference: processing output from window {i + 1} of {n_windows}." + ) + with timer.context("data_writer"): + writer.append_batch( + batch=batch, + ) + with timer.context("aggregator"): + logs = aggregator.record_batch( + data=batch, + ) + with timer.context("wandb_logging"): + record_logs(logs) + with timer.context("data_writer"): + prognostic_state = looper.get_prognostic_state() + writer.write(prognostic_state, "restart.nc") diff --git a/fme/fme/core/generics/optimization.py b/fme/fme/core/generics/optimization.py new file mode 100644 index 0000000..4715c00 --- /dev/null +++ b/fme/fme/core/generics/optimization.py @@ -0,0 +1,50 @@ +import abc +import contextlib + +import torch +from torch import nn + + +class OptimizationABC(abc.ABC): + @contextlib.contextmanager + @abc.abstractmethod + def autocast(self): ... + + @property + @abc.abstractmethod + def learning_rate(self) -> float: ... + + @abc.abstractmethod + def set_mode(self, module: nn.Module): + """ + Sets the mode of the module to train. + """ + ... + + @abc.abstractmethod + def step_scheduler(self, valid_loss: float): + """ + Step the scheduler. + + Args: + valid_loss: The validation loss. Used in schedulers which change the + learning rate based on whether the validation loss is decreasing. + """ + ... + + @abc.abstractmethod + def step_weights(self, loss: torch.Tensor): ... + + @abc.abstractmethod + def get_state(self): + """ + Returns state as a serializable data structure. + """ + ... + + @abc.abstractmethod + def load_state(self, state): + """ + Loads state from a serializable data structure. + """ + ... 
diff --git a/fme/fme/core/generics/test_looper.py b/fme/fme/core/generics/test_looper.py new file mode 100644 index 0000000..04e671f --- /dev/null +++ b/fme/fme/core/generics/test_looper.py @@ -0,0 +1,503 @@ +import datetime +import unittest.mock +from collections import namedtuple +from typing import Callable, Iterable, Optional, Tuple + +import numpy as np +import pytest +import torch +import xarray as xr + +import fme +from fme.ace.data_loading.batch_data import BatchData, PrognosticState +from fme.ace.stepper import SingleModuleStepperConfig +from fme.core.coordinates import HybridSigmaPressureCoordinate +from fme.core.device import get_device +from fme.core.generics.data import SimpleInferenceData +from fme.core.generics.inference import ( + Looper, + PredictFunction, + get_record_to_wandb, + run_inference, +) +from fme.core.gridded_ops import LatLonOperations +from fme.core.loss import WeightedMappingLossConfig +from fme.core.normalizer import NormalizationConfig +from fme.core.registry.module import ModuleSelector +from fme.core.testing.wandb import mock_wandb +from fme.core.timing import GlobalTimer +from fme.core.typing_ import TensorDict, TensorMapping + +SphericalData = namedtuple("SphericalData", ["data", "area_weights", "vertical_coord"]) + + +def get_data( + names: Iterable[str], shape: Tuple[int, int, int, int, int] +) -> SphericalData: + data = {} + n_lat = shape[2] + n_lon = shape[3] + shape_without_z = shape[:-1] + nz = shape[-1] + lats = torch.linspace(-89.5, 89.5, n_lat) + for name in names: + data[name] = torch.rand(*shape_without_z, device=fme.get_device()) + area_weights = fme.spherical_area_weights(lats, n_lon).to(fme.get_device()) + ak, bk = torch.arange(nz), torch.arange(nz) + vertical_coord = HybridSigmaPressureCoordinate(ak, bk) + return SphericalData(data, area_weights, vertical_coord) + + +def get_scalar_data(names, value): + return {n: np.array([value], dtype=np.float32) for n in names} + + +class MockLoader(torch.utils.data.DataLoader): + def __init__( + self, + shape: tuple, + names: Iterable[str], + n_windows: int, + time: Optional[xr.DataArray] = None, + ): + device = fme.get_device() + self._data = {n: torch.rand(*shape, device=device) for n in names} + if time is None: + self._time = xr.DataArray(np.zeros(shape[:2]), dims=["sample", "time"]) + elif time.shape != shape[:2]: + raise ValueError( + "Time shape must match the first two dimensions of the data." 
+ ) + else: + self._time = time + self._n_windows = n_windows + self._current_window = 0 + + def __iter__(self): + return self + + def __len__(self) -> int: + return self._n_windows + + def __next__(self) -> BatchData: + if self._current_window < self._n_windows: + self._current_window += 1 + return BatchData.new_on_device( + data=self._data, + time=self._time + + (self._current_window - 1) * (self._time.shape[1] - 1), + ) + else: + raise StopIteration + + +def _get_stepper(): + class ChannelSum(torch.nn.Module): + def forward(self, x): + summed = torch.sum(x, dim=-3, keepdim=True) + diagnostic = torch.rand_like(summed) + return torch.concat([x, diagnostic], dim=-3) + + n_samples = 2 + n_time = 5 + n_lat = 2 + n_lon = 4 + nz = 3 + shape = (n_samples, n_time, n_lat, n_lon, nz) + in_names = ["forcing", "prognostic"] + out_names = ["prognostic", "diagnostic"] + all_names = list(set(in_names + out_names)) + + spherical_data = get_data(all_names, shape) + time = xr.DataArray( + np.repeat(np.expand_dims(np.arange(n_time), axis=0), n_samples, axis=0), + dims=["sample", "time"], + ) + + img_shape = spherical_data.data[in_names[0]].shape[2:] + gridded_operations = LatLonOperations(spherical_data.area_weights) + vertical_coordinate = spherical_data.vertical_coord + config = SingleModuleStepperConfig( + builder=ModuleSelector(type="prebuilt", config={"module": ChannelSum()}), + in_names=in_names, + out_names=out_names, + normalization=NormalizationConfig( + means=get_scalar_data(all_names, 0.0), + stds=get_scalar_data(all_names, 1.0), + ), + loss=WeightedMappingLossConfig(), + ) + stepper = config.get_stepper( + img_shape, + gridded_operations, + vertical_coordinate, + datetime.timedelta(seconds=1), + ) + return stepper, spherical_data, time, in_names, out_names + + +def test_looper(): + stepper, spherical_data, time, in_names, out_names = _get_stepper() + forcing_names = set(in_names) - set(out_names) + shape = spherical_data.data[in_names[0]].shape + initial_condition = BatchData.new_on_device( + data={n: spherical_data.data[n][:, :1] for n in spherical_data.data}, + time=time[:, 0:1], + ).get_start( + prognostic_names=stepper.prognostic_names, + n_ic_timesteps=1, + ) + loader = MockLoader(shape, forcing_names, 3, time=time) + looper = Looper( + predict=stepper.predict, + data=SimpleInferenceData(initial_condition, loader), + ) + + expected_output_shape = (shape[0], shape[1] - 1, shape[2], shape[3]) + for batch in looper: + assert set(out_names) == set(batch.data) + for name in out_names: + assert batch.data[name].shape == expected_output_shape + + +def test_looper_paired(): + stepper, spherical_data, time, in_names, out_names = _get_stepper() + forcing_names = set(in_names) - set(out_names) + shape = spherical_data.data[in_names[0]].shape + initial_condition = BatchData.new_on_device( + data={n: spherical_data.data[n][:, :1] for n in spherical_data.data}, + time=time[:, 0:1], + ).get_start( + prognostic_names=stepper.prognostic_names, + n_ic_timesteps=1, + ) + loader = MockLoader(shape, forcing_names, 3, time=time) + looper = Looper( + predict=stepper.predict_paired, + data=SimpleInferenceData(initial_condition, loader), + ) + + expected_output_shape = (shape[0], shape[1] - 1, shape[2], shape[3]) + for batch in looper: + assert set(out_names) == set(batch.prediction) + assert set(forcing_names) == set(batch.target) + for name in out_names: + assert batch.prediction[name].shape == expected_output_shape + for name in forcing_names: + assert batch.target[name].shape == expected_output_shape + + +def 
_mock_compute_derived_quantities(data, forcing_data): + data_name = list(data)[0] + forcing_name = list(forcing_data)[0] + derived = {"derived": data[data_name] + forcing_data[forcing_name]} + return {**data, **derived} + + +def test_looper_paired_with_derived_variables(): + mock_derive_func = unittest.mock.MagicMock( + side_effect=_mock_compute_derived_quantities + ) + stepper, spherical_data, time, in_names, out_names = _get_stepper() + stepper.derive_func = mock_derive_func + forcing_names = set(in_names) - set(out_names) + shape = spherical_data.data[in_names[0]].shape + initial_condition = BatchData.new_on_device( + {n: spherical_data.data[n][:, :1] for n in spherical_data.data}, + time=time[:, 0:1], + ).get_start( + prognostic_names=stepper.prognostic_names, + n_ic_timesteps=1, + ) + loader = MockLoader(shape, forcing_names, 2, time=time) + looper = Looper( + predict=stepper.predict_paired, + data=SimpleInferenceData(initial_condition, loader), + ) + + for batch in looper: + assert "derived" in batch.prediction + assert "derived" in batch.target + mock_derive_func.assert_called() + + +def test_looper_paired_with_target_data(): + stepper, spherical_data, time, in_names, out_names = _get_stepper() + all_names = list(set(in_names + out_names)) + shape = spherical_data.data[in_names[0]].shape + initial_condition = BatchData.new_on_device( + data={n: spherical_data.data[n][:, :1] for n in spherical_data.data}, + time=time[:, 0:1], + ).get_start( + prognostic_names=stepper.prognostic_names, + n_ic_timesteps=1, + ) + loader = MockLoader(shape, all_names, 2, time=time) + looper = Looper( + predict=stepper.predict_paired, + data=SimpleInferenceData(initial_condition, loader), + ) + + for batch in looper: + assert set(out_names) == set(batch.prediction) + assert set(all_names) == set(batch.target) + + +def test_looper_paired_with_target_data_and_derived_variables(): + mock_derive_func = unittest.mock.MagicMock( + side_effect=_mock_compute_derived_quantities + ) + stepper, spherical_data, time, in_names, out_names = _get_stepper() + stepper.derive_func = mock_derive_func + all_names = list(set(in_names + out_names)) + shape = spherical_data.data[in_names[0]].shape + initial_condition = BatchData.new_on_device( + data={n: spherical_data.data[n][:, :1] for n in spherical_data.data}, + time=time[:, 0:1], + ).get_start( + prognostic_names=stepper.prognostic_names, + n_ic_timesteps=1, + ) + loader = MockLoader(shape, all_names, 2, time=time) + looper = Looper( + predict=stepper.predict_paired, + data=SimpleInferenceData(initial_condition, loader), + ) + + for batch in looper: + assert set(out_names + ["derived"]) == set(batch.prediction) + assert set(all_names + ["derived"]) == set(batch.target) + mock_derive_func.assert_called() + + +def get_batch_data( + start_time, + n_timesteps, +): + n_samples = 1 + n_lat = 3 + n_lon = 4 + time_values = torch.arange( + start_time, start_time + n_timesteps, device=get_device() + )[None, :, None, None] + time_axis = torch.broadcast_to( + start_time + torch.arange(n_timesteps)[None, :], (n_samples, n_timesteps) + ) + time = xr.DataArray(time_axis, dims=["sample", "time"]) + return BatchData.new_on_device( + data={ + "var": torch.broadcast_to( + time_values, (n_samples, n_timesteps, n_lat, n_lon) + ) + }, + time=time, + ) + + +class PlusOneStepper: + def __init__( + self, + n_ic_timesteps: int, + derive_func: Optional[ + Callable[[TensorMapping, TensorMapping], TensorDict] + ] = None, + ): + self.n_ic_timesteps = n_ic_timesteps + if derive_func is None: + 
self.derive_func: Callable[[TensorMapping, TensorMapping], TensorDict] = ( + unittest.mock.MagicMock(side_effect=lambda x, y=None: dict(x)) + ) + else: + self.derive_func = derive_func + _: PredictFunction[ # for type checking + PrognosticState, + BatchData, + BatchData, + ] = self.predict + + def predict( + self, + initial_condition: PrognosticState, + forcing: BatchData, + compute_derived_variables: bool = False, + ) -> Tuple[BatchData, PrognosticState]: + ic_state = initial_condition.as_batch_data() + n_forward_steps = forcing.time.shape[1] - self.n_ic_timesteps + out_tensor = torch.zeros( + ic_state.data["var"].shape[0], + n_forward_steps, + *ic_state.data["var"].shape[2:], + device=ic_state.data["var"].device, + dtype=ic_state.data["var"].dtype, + ) + out_tensor[:, 0, ...] = ic_state.data["var"][:, -1, ...] + 1 + for i in range(1, n_forward_steps): + out_tensor[:, i, ...] = out_tensor[:, i - 1, ...] + 1 + data = BatchData.new_on_device( + data={"var": out_tensor}, + time=forcing.time[:, self.n_ic_timesteps :], + ) + if compute_derived_variables: + data = data.compute_derived_variables( + derive_func=self.derive_func, forcing_data=data + ) + return data, data.get_end(["var"], self.n_ic_timesteps) + + def get_forward_data( + self, + forcing: BatchData, + compute_derived_variables: bool = False, + ) -> BatchData: + if compute_derived_variables: + forcing = forcing.compute_derived_variables( + derive_func=self.derive_func, forcing_data=forcing + ) + return forcing.remove_initial_condition(self.n_ic_timesteps) + + +def test_looper_simple_batch_data(): + n_ic_timesteps = 1 + n_forward_steps = 2 + n_iterations = 10 + mock_derive_func = unittest.mock.MagicMock( + side_effect=lambda batch_data, forcing_data: batch_data + ) + stepper = PlusOneStepper( + n_ic_timesteps=n_ic_timesteps, derive_func=mock_derive_func + ) + initial_condition = get_batch_data( + 0, + n_timesteps=n_ic_timesteps, + ).get_start(prognostic_names=["var"], n_ic_timesteps=n_ic_timesteps) + loader = [ + get_batch_data( + i, + n_ic_timesteps + n_forward_steps, + ) + for i in range(0, n_iterations * n_forward_steps, n_forward_steps) + ] + + mock_predict = unittest.mock.MagicMock(side_effect=stepper.predict) + with unittest.mock.patch.object(stepper, "predict", mock_predict): + with GlobalTimer(): + timer = GlobalTimer.get_instance() + looper = Looper( + predict=stepper.predict, + data=SimpleInferenceData(initial_condition, loader), + ) + for i, batch in enumerate(looper): + for j in range(batch.time.shape[1]): + assert torch.allclose( + batch.data["var"][:, j, ...], + torch.as_tensor(n_ic_timesteps + i * n_forward_steps + j), + ) + times = timer.get_durations() + assert times["data_loading"] > 0 + assert mock_derive_func.call_count == n_iterations + assert mock_predict.call_count == n_iterations + # we mocked out the implicit calls that happen in .predict + # if this changed, update the test + assert "forward_prediction" not in times + assert "compute_derived_variables" not in times + + +def get_mock_aggregator( + n_ic_timesteps: int, +) -> unittest.mock.MagicMock: + mock_aggregator = unittest.mock.MagicMock() + + # record_batch will start at step n_ic_timesteps + i = n_ic_timesteps + + def record_batch_side_effect( + data: BatchData, + ): + nonlocal i + ret = [{"step": j} for j in range(i, i + data.time.shape[1])] + i += data.time.shape[1] + return ret + + mock_aggregator = unittest.mock.MagicMock() + mock_aggregator.record_initial_condition = unittest.mock.MagicMock( + return_value=[{"step": j} for j in range(n_ic_timesteps)] 
+ ) + + def get_summary_logs_side_effect(): + # we expect this gets called _outside_ of run_inference + raise ValueError("should not be called") + + mock_aggregator.get_summary_logs = unittest.mock.MagicMock( + side_effect=get_summary_logs_side_effect + ) + mock_aggregator.record_batch = unittest.mock.MagicMock( + side_effect=record_batch_side_effect + ) + return mock_aggregator + + +def get_mock_writer() -> unittest.mock.MagicMock: + mock_writer = unittest.mock.MagicMock() + return mock_writer + + +@pytest.mark.parametrize( + "n_ic_timesteps, n_forward_steps, n_iterations", + [ + pytest.param(1, 2, 5, id="n_ic_timesteps=1"), + pytest.param(2, 2, 5, id="n_ic_timesteps=2"), + ], +) +def test_run_inference_simple( + n_ic_timesteps: int, n_forward_steps: int, n_iterations: int +): + mock_derive_func = unittest.mock.MagicMock( + side_effect=lambda batch_data, forcing_data: batch_data + ) + stepper = PlusOneStepper( + n_ic_timesteps=n_ic_timesteps, derive_func=mock_derive_func + ) + initial_condition = get_batch_data( + 0, + n_timesteps=n_ic_timesteps, + ).get_start(prognostic_names=["var"], n_ic_timesteps=n_ic_timesteps) + loader = [ + get_batch_data( + i, + n_ic_timesteps + n_forward_steps, + ) + for i in range(0, n_iterations * n_forward_steps, n_forward_steps) + ] + mock_writer = get_mock_writer() + mock_aggregator = get_mock_aggregator(n_ic_timesteps) + + with GlobalTimer(): + with mock_wandb() as wandb: + wandb.configure(log_to_wandb=True) + record_logs = unittest.mock.MagicMock( + side_effect=get_record_to_wandb("inference") + ) # this init must be within mock_wandb context + run_inference( + predict=stepper.predict, + data=SimpleInferenceData(initial_condition, loader), + writer=mock_writer, + aggregator=mock_aggregator, + record_logs=record_logs, + ) + wandb_logs = wandb.get_logs() + timer = GlobalTimer.get_instance() + times = timer.get_durations() + assert times["wandb_logging"] > 0 + assert times["data_writer"] > 0 + assert times["aggregator"] > 0 + assert mock_writer.write.call_count == 2 + assert mock_aggregator.record_initial_condition.call_count == 1 + assert mock_writer.append_batch.call_count == n_iterations + assert mock_aggregator.record_batch.call_count == n_iterations + assert len(wandb_logs) == n_ic_timesteps + n_iterations * n_forward_steps + assert wandb_logs == [ + {"inference/step": i} + for i in range(n_ic_timesteps + n_iterations * n_forward_steps) + ] + assert ( + record_logs.call_count == n_iterations + 1 + ) # +1 for the initial condition diff --git a/fme/fme/core/generics/test_trainer.py b/fme/fme/core/generics/test_trainer.py new file mode 100644 index 0000000..71d6635 --- /dev/null +++ b/fme/fme/core/generics/test_trainer.py @@ -0,0 +1,638 @@ +import contextlib +import dataclasses +import itertools +import os +import unittest.mock +from typing import Any, Dict, Optional, Tuple, Type, TypeVar, cast + +import numpy as np +import pytest +import torch + +from fme.ace.data_loading.batch_data import DataLoader +from fme.ace.stepper import TrainOutputABC, TrainStepperABC +from fme.core.ema import EMATracker +from fme.core.generics.aggregator import ( + AggregatorABC, + InferenceAggregatorABC, + InferenceLog, + InferenceLogs, +) +from fme.core.generics.data import GriddedDataABC, InferenceDataABC +from fme.core.generics.optimization import OptimizationABC +from fme.core.generics.trainer import ( + AggregatorBuilderABC, + CheckpointPaths, + TrainConfigProtocol, + Trainer, +) +from fme.core.optimization import Optimization +from fme.core.scheduler import SchedulerConfig 
+from fme.core.typing_ import Slice, TensorDict, TensorMapping + + +class PSType: + pass + + +class BDType: + pass + + +class FDType: + pass + + +class SDType: + pass + + +class TrainOutput(TrainOutputABC): + def get_metrics(self) -> Dict[str, torch.Tensor]: + return {} + + +class TrainData(GriddedDataABC[BDType]): + @property + def batch_size(self) -> int: + return 1 + + @property + def loader(self) -> DataLoader[BDType]: + return [BDType() for _ in range(self.n_batches)] + + @property + def n_samples(self) -> int: + return 3 + + @property + def n_batches(self) -> int: + return 5 + + @property + def n_forward_steps(self) -> int: + return 1 + + def __init__(self): + self._set_epoch = unittest.mock.MagicMock() + self._log_info = unittest.mock.MagicMock() + + def set_epoch(self, epoch: int) -> None: + self._set_epoch(epoch) + + def log_info(self, name: str) -> None: + self._log_info(name) + + @property + def set_epoch_mock(self) -> unittest.mock.Mock: + return self._set_epoch + + @property + def log_info_mock(self) -> unittest.mock.Mock: + return self._log_info + + +class InferenceData(InferenceDataABC[PSType, FDType]): + def __init__(self, n_time_windows: int = 1): + self.n_time_windows = n_time_windows + + @property + def initial_condition(self) -> PSType: + return PSType() + + @property + def loader(self) -> DataLoader[FDType]: + return [FDType() for _ in range(self.n_time_windows)] + + @property + def n_window_forward_steps(self) -> int: + return 1 + + +class TrainStepper(TrainStepperABC[PSType, BDType, FDType, SDType, TrainOutput]): + SelfType = TypeVar("SelfType", bound="TrainStepper") + + def __init__( + self, + state: Optional[Dict[str, Any]] = None, + ): + self._modules = torch.nn.ModuleList([torch.nn.Linear(1, 1, bias=False)]) + self._modules[0].weight.data.fill_(0.0) + if state is not None: + self._state = state + else: + self._state = {} + self.loaded_state: Optional[Dict[str, Any]] = None + + def get_state(self) -> Dict[str, Any]: + return {**self._state, "modules": self._modules.state_dict()} + + def load_state(self, state: Dict[str, Any]) -> None: + self._state = state + self.loaded_state = state + self._modules.load_state_dict(state["modules"]) + + @classmethod + def from_state(cls: Type[SelfType], state: Dict[str, Any]) -> SelfType: + ret = cls() + ret.load_state(state) + return ret + + @property + def modules(self) -> torch.nn.ModuleList: + return self._modules + + @property + def n_ic_timesteps(self) -> int: + return 1 + + def normalize(self, data: TensorMapping) -> TensorDict: + return dict(data) + + def predict_paired( + self, + initial_condition: PSType, + forcing: FDType, + compute_derived_variables: bool = False, + ) -> Tuple[SDType, PSType]: + return SDType(), PSType() + + def train_on_batch( + self, + batch: BDType, + optimization: OptimizationABC, + compute_derived_variables: bool = False, + ) -> TrainOutput: + optimization.step_weights(torch.tensor(float("inf"))) + return TrainOutput() + + +@dataclasses.dataclass +class Config: + experiment_dir: str = "test_experiment_dir" + checkpoint_dir: str = "test_checkpoint_dir" + max_epochs: int = 2 + save_checkpoint: bool = True + validate_using_ema: bool = True + log_train_every_n_batches: int = 1 + inference_n_forward_steps: int = 1 + checkpoint_save_epochs: Optional[Slice] = None + ema_checkpoint_save_epochs: Optional[Slice] = None + segment_epochs: Optional[int] = None + clean_wandb = unittest.mock.MagicMock() + + def __post_init__(self): + self.get_inference_epochs = unittest.mock.MagicMock( + return_value=[i for i in 
range(self.max_epochs)] + ) + + +_: TrainConfigProtocol = Config() + + +class TrainAggregator(AggregatorABC[TrainOutput]): + def __init__(self, train_loss: float): + self.train_loss = train_loss + + def record_batch(self, batch: TrainOutput) -> None: + pass + + def get_logs(self, label: str) -> Dict[str, Any]: + return {f"{label}/mean/loss": self.train_loss} + + +class ValidationAggregator(AggregatorABC[TrainOutput]): + def __init__(self, validation_loss: float): + self.validation_loss = validation_loss + + def record_batch(self, batch: TrainOutput) -> None: + pass + + def get_logs(self, label: str) -> Dict[str, Any]: + return {f"{label}/mean/loss": self.validation_loss} + + +class InferenceAggregator(InferenceAggregatorABC[PSType, SDType]): + def __init__(self, inference_loss: float): + self.inference_loss = inference_loss + + def record_batch(self, data: SDType) -> InferenceLogs: + return [{}] + + def record_initial_condition(self, initial_condition: PSType) -> InferenceLogs: + return [{}] + + def get_summary_logs(self) -> InferenceLog: + return {"time_mean_norm/rmse/channel_mean": self.inference_loss} + + +class AggregatorBuilder(AggregatorBuilderABC[PSType, TrainOutput, SDType]): + def __init__( + self, + train_losses: np.ndarray, + validation_losses: np.ndarray, + inference_losses: np.ndarray, + ): + self.train_losses = train_losses + self.validation_losses = validation_losses + self.inference_losses = inference_losses + self._train_calls = 0 + self._validation_calls = 0 + self._inference_calls = 0 + + def get_train_aggregator(self) -> AggregatorABC[TrainOutput]: + ret = TrainAggregator(self.train_losses[self._train_calls]) + self._train_calls += 1 + return ret + + def get_validation_aggregator(self) -> AggregatorABC[TrainOutput]: + ret = ValidationAggregator(self.validation_losses[self._validation_calls]) + self._validation_calls += 1 + return ret + + def get_inference_aggregator(self) -> InferenceAggregatorABC[PSType, SDType]: + ret = InferenceAggregator(self.inference_losses[self._inference_calls]) + self._inference_calls += 1 + return ret + + +def get_trainer( + tmp_path: str, + checkpoint_save_epochs: Optional[Slice] = None, + segment_epochs: Optional[int] = None, + max_epochs: int = 8, + checkpoint_dir: Optional[str] = None, + stepper_state: Optional[Dict[str, Any]] = None, + train_losses: Optional[np.ndarray] = None, + validation_losses: Optional[np.ndarray] = None, + inference_losses: Optional[np.ndarray] = None, + stepper_module_values: Optional[np.ndarray] = None, + ema_decay: float = 0.9999, + validate_using_ema: bool = True, +) -> Tuple[TrainConfigProtocol, Trainer]: + if checkpoint_dir is None: + checkpoint_dir = os.path.join(tmp_path, "checkpoints") + if train_losses is None: + train_losses = np.zeros(max_epochs) + if validation_losses is None: + validation_losses = np.zeros(max_epochs) + if inference_losses is None: + inference_losses = np.zeros(max_epochs) + if stepper_module_values is None: + stepper_module_values = np.zeros(max_epochs) + train_data = TrainData() + validation_data = TrainData() + inference_data = InferenceData() + stepper = TrainStepper(state=stepper_state) + + def build_optimization(modules: torch.nn.ModuleList) -> Optimization: + if len(modules) != 1: + raise ValueError("Expected 1 linear module with 1 weight") + if not isinstance(modules[0], torch.nn.Linear): + raise ValueError("Expected a linear module") + module = modules[0] + if module.weight.numel() != 1: + raise ValueError("Expected a linear module with 1 weight") + i = 0 + + opt = 
Optimization( + parameters=itertools.chain(*[module.parameters() for module in modules]), + optimizer_type="Adam", + lr=0.01, + max_epochs=max_epochs, + scheduler=SchedulerConfig(), + enable_automatic_mixed_precision=False, + kwargs={}, + ) + original_step_scheduler = opt.step_scheduler + + def step_scheduler_side_effect(*args, **kwargs): + original_step_scheduler(*args, **kwargs) + nonlocal i + i += 1 + + opt.step_scheduler = unittest.mock.MagicMock( # type: ignore + side_effect=step_scheduler_side_effect + ) + + def step_weights_side_effect(*args, **kwargs): + if stepper_module_values is None: + raise ValueError("stepper_module_values is None") + module.weight.data.fill_(stepper_module_values[i]) + + opt.step_weights = unittest.mock.MagicMock(side_effect=step_weights_side_effect) # type: ignore + return opt + + def build_ema(modules: torch.nn.ModuleList) -> EMATracker: + return EMATracker(modules, decay=ema_decay) + + config: TrainConfigProtocol = Config( + experiment_dir=tmp_path, + checkpoint_dir=checkpoint_dir, + checkpoint_save_epochs=checkpoint_save_epochs, + segment_epochs=segment_epochs, + max_epochs=max_epochs, + validate_using_ema=validate_using_ema, + ) + aggregator_builder = AggregatorBuilder( + train_losses=train_losses, + validation_losses=validation_losses, + inference_losses=inference_losses, + ) + callback = unittest.mock.MagicMock() + return config, Trainer( + train_data=train_data, + validation_data=validation_data, + inference_data=inference_data, + stepper=stepper, + build_optimization=build_optimization, + build_ema=build_ema, + config=config, + aggregator_builder=aggregator_builder, + end_of_batch_callback=callback, + do_gc_collect=False, # for much faster tests + ) + + +@pytest.mark.parametrize( + "checkpoint_save_epochs", + [None, Slice(start=2, stop=3), Slice(start=1, step=2)], +) +def test_trainer(tmp_path: str, checkpoint_save_epochs: Optional[Slice]): + config, trainer = get_trainer(tmp_path, checkpoint_save_epochs, max_epochs=4) + trainer.train() + assert os.path.exists(config.experiment_dir) + assert os.path.exists(config.checkpoint_dir) + paths = CheckpointPaths(config.checkpoint_dir) + assert os.path.exists(paths.latest_checkpoint_path) + assert os.path.exists(paths.best_checkpoint_path) + assert os.path.exists(paths.best_inference_checkpoint_path) + assert os.path.exists(paths.ema_checkpoint_path) + save_epochs = list(range(config.max_epochs)) + if checkpoint_save_epochs is not None: + save_epochs = save_epochs[checkpoint_save_epochs.slice] + else: + save_epochs = [] + for i in range(config.max_epochs): + if i in save_epochs: + assert os.path.exists(paths.epoch_checkpoint_path(i)) + else: + assert not os.path.exists(paths.epoch_checkpoint_path(i)) + assert not os.path.exists(paths.ema_epoch_checkpoint_path(i)) + train_data = cast(TrainData, trainer.train_data) + valid_data = cast(TrainData, trainer.valid_data) + assert train_data.set_epoch_mock.mock_calls == [ + unittest.mock.call(i) for i in range(config.max_epochs) + ] + assert valid_data.set_epoch_mock.mock_calls == [] # no shuffling + assert train_data.log_info_mock.called + assert valid_data.log_info_mock.called + + +@pytest.mark.parametrize("segment_epochs", [1, 2, 3]) +def test_segmented_trainer_runs_correct_epochs(tmp_path: str, segment_epochs: int): + max_epochs = 4 + total_segments = max_epochs // segment_epochs + for i in range(total_segments): + config, trainer = get_trainer( + tmp_path, + checkpoint_dir=os.path.join(tmp_path, "checkpoint_dir"), # same dir for all + # for speed, don't save 
per-epoch checkpoints for this test + checkpoint_save_epochs=Slice(start=0, stop=0), + segment_epochs=segment_epochs, + max_epochs=max_epochs, + ) + trainer.train() + paths = CheckpointPaths(config.checkpoint_dir) + assert os.path.exists(paths.latest_checkpoint_path) + train_data = cast(TrainData, trainer.train_data) + assert train_data.set_epoch_mock.mock_calls == [ + unittest.mock.call(i) + for i in range( + i * segment_epochs, + min((i + 1) * segment_epochs, config.max_epochs), + ) + ] + + +class TrainingInterrupted(Exception): + pass + + +@contextlib.contextmanager +def fail_after_calls_patch(object, method: str, call_count: int): + total_calls = 0 + original_method = getattr(object, method) + + def wrapper(*args, **kwargs): + nonlocal total_calls + total_calls += 1 + if total_calls >= call_count: + raise TrainingInterrupted() + return original_method(*args, **kwargs) + + with unittest.mock.patch.object(object, method) as mock: + mock.side_effect = wrapper + try: + yield mock + except TrainingInterrupted: + pass + + +@pytest.mark.parametrize( + "interrupt_method", + ["train_one_epoch", "validate_one_epoch", "inference_one_epoch"], +) +def test_resume_after_interrupted_training(tmp_path: str, interrupt_method: str): + max_epochs = 4 + calls_before_interrupt = 2 + stepper_state = {"foo": "bar"} + config, trainer = get_trainer( + tmp_path, + stepper_state=stepper_state, + checkpoint_save_epochs=Slice(start=0, stop=0), + max_epochs=max_epochs, + ) + with fail_after_calls_patch(trainer, interrupt_method, calls_before_interrupt): + trainer.train() + train_data = cast(TrainData, trainer.train_data) + assert train_data.set_epoch_mock.mock_calls == [ + unittest.mock.call(i) for i in range(calls_before_interrupt) + ] + paths = CheckpointPaths(config.checkpoint_dir) + assert os.path.exists(paths.latest_checkpoint_path) + _, trainer = get_trainer( + tmp_path, + checkpoint_save_epochs=Slice(start=0, stop=0), + max_epochs=max_epochs, + stepper_state=stepper_state, + ) + trainer.train() + train_data = cast(TrainData, trainer.train_data) + assert train_data.set_epoch_mock.mock_calls == [ + unittest.mock.call(i) for i in range(calls_before_interrupt - 1, max_epochs) + ] + stepper = cast(TrainStepper, trainer.stepper) + assert stepper.loaded_state is not None + assert stepper.loaded_state["foo"] == "bar" + assert "modules" in stepper.loaded_state + assert len(stepper.loaded_state) == 2 + + +@pytest.mark.parametrize("ema_decay", [0.05, 0.99]) +@pytest.mark.parametrize("validate_using_ema", [True, False]) +def test_saves_correct_ema_checkpoints( + tmp_path: str, ema_decay: float, validate_using_ema: bool +): + config, trainer = get_trainer( + tmp_path, + checkpoint_dir=os.path.join(tmp_path, "checkpoint_dir"), # same dir for all + ema_decay=ema_decay, + validate_using_ema=validate_using_ema, + ) + valid_loss = 0.1 + inference_error = 0.2 + trainer.stepper.modules[0].weight.data.fill_(1.0) + trainer._ema(model=trainer.stepper.modules) + trainer.save_all_checkpoints(valid_loss=valid_loss, inference_error=inference_error) + paths = CheckpointPaths(config.checkpoint_dir) + assert os.path.exists(paths.ema_checkpoint_path) + ema_checkpoint = torch.load(paths.ema_checkpoint_path) + ema_weight = 1.0 - min(ema_decay, 2.0 / 11.0) + np.testing.assert_allclose( + ema_checkpoint["stepper"]["modules"]["0.weight"].cpu().numpy(), + ema_weight, + atol=1e-7, + ) + assert ema_checkpoint["best_validation_loss"] == valid_loss + assert ema_checkpoint["best_inference_error"] == inference_error + assert 
os.path.exists(paths.latest_checkpoint_path) + latest_checkpoint = torch.load(paths.latest_checkpoint_path) + np.testing.assert_allclose( + latest_checkpoint["stepper"]["modules"]["0.weight"].cpu().numpy(), + 1.0, + atol=1e-7, + ) + assert latest_checkpoint["best_validation_loss"] == valid_loss + assert latest_checkpoint["best_inference_error"] == inference_error + if validate_using_ema: + best_weight = ema_weight + else: + best_weight = 1.0 + assert os.path.exists(paths.best_checkpoint_path) + best_checkpoint = torch.load(paths.best_checkpoint_path) + assert best_checkpoint["best_validation_loss"] == valid_loss + assert best_checkpoint["best_inference_error"] == inference_error + np.testing.assert_allclose( + best_checkpoint["stepper"]["modules"]["0.weight"].cpu().numpy(), + best_weight, + atol=1e-7, + ) + best_inference_checkpoint = torch.load(paths.best_inference_checkpoint_path) + assert best_inference_checkpoint["best_validation_loss"] == valid_loss + assert best_inference_checkpoint["best_inference_error"] == inference_error + np.testing.assert_allclose( + best_inference_checkpoint["stepper"]["modules"]["0.weight"].cpu().numpy(), + best_weight, + atol=1e-7, + ) + + +@pytest.mark.parametrize( + "segment_epochs, best_val_epoch, best_inference_epoch", + [ + (None, 1, 1), + (None, 3, 5), + (2, 3, 5), + (2, 5, 3), + (2, 4, 6), + (2, 6, 4), + ], +) +def test_saves_correct_non_ema_epoch_checkpoints( + tmp_path: str, + segment_epochs: Optional[int], + best_val_epoch: int, + best_inference_epoch: int, +): + max_epochs = 10 + if segment_epochs is None: + total_segments = 1 + segment_epochs_value = max_epochs + else: + total_segments = max_epochs // segment_epochs + segment_epochs_value = segment_epochs + train_losses = np.random.rand(max_epochs) + 0.01 + val_losses = np.random.rand(max_epochs) + 0.01 + inference_losses = np.random.rand(max_epochs) + 0.01 + val_losses[best_val_epoch - 1] = 0.0 + inference_losses[best_inference_epoch - 1] = 0.0 + module_values = np.random.rand(max_epochs) + for i in range(total_segments): + config, trainer = get_trainer( + tmp_path, + checkpoint_dir=os.path.join(tmp_path, "checkpoint_dir"), # same dir for all + # for speed, don't save per-epoch checkpoints for this test + checkpoint_save_epochs=Slice(start=0, stop=0), + segment_epochs=segment_epochs, + max_epochs=max_epochs, + train_losses=train_losses[ + i * segment_epochs_value : (i + 1) * segment_epochs_value + ], + validation_losses=val_losses[ + i * segment_epochs_value : (i + 1) * segment_epochs_value + ], + inference_losses=inference_losses[ + i * segment_epochs_value : (i + 1) * segment_epochs_value + ], + stepper_module_values=module_values[ + i * segment_epochs_value : (i + 1) * segment_epochs_value + ], + validate_using_ema=False, + ) + trainer.train() + paths = CheckpointPaths(config.checkpoint_dir) + assert os.path.exists(paths.latest_checkpoint_path) + train_data = cast(TrainData, trainer.train_data) + assert train_data.set_epoch_mock.mock_calls == [ + unittest.mock.call(i) + for i in range( + i * segment_epochs_value, + min((i + 1) * segment_epochs_value, config.max_epochs), + ) + ] + latest_checkpoint = torch.load(paths.latest_checkpoint_path) + assert latest_checkpoint["epoch"] == min( + max_epochs, (i + 1) * segment_epochs_value + ) + np.testing.assert_allclose( + latest_checkpoint["stepper"]["modules"]["0.weight"].cpu().numpy(), + module_values[min((i + 1) * segment_epochs_value - 1, max_epochs - 1)], + ) + paths = CheckpointPaths(config.checkpoint_dir) + assert 
os.path.exists(paths.latest_checkpoint_path) + assert os.path.exists(paths.best_checkpoint_path) + assert os.path.exists(paths.best_inference_checkpoint_path) + assert os.path.exists(paths.ema_checkpoint_path) + best_checkpoint = torch.load(paths.best_checkpoint_path) + assert best_checkpoint["epoch"] == best_val_epoch + assert best_checkpoint["best_validation_loss"] == 0.0 + assert best_checkpoint["best_inference_error"] == np.min( + inference_losses[:best_val_epoch] + ) + np.testing.assert_allclose( + best_checkpoint["stepper"]["modules"]["0.weight"].cpu().numpy(), + module_values[best_val_epoch - 1], + ) + best_inference_checkpoint = torch.load(paths.best_inference_checkpoint_path) + assert best_inference_checkpoint["epoch"] == best_inference_epoch + assert best_inference_checkpoint["best_validation_loss"] == np.min( + val_losses[:best_inference_epoch] + ) + assert best_inference_checkpoint["best_inference_error"] == 0.0 + latest_checkpoint = torch.load(paths.latest_checkpoint_path) + assert latest_checkpoint["epoch"] == max_epochs + np.testing.assert_allclose( + latest_checkpoint["stepper"]["modules"]["0.weight"].cpu().numpy(), + module_values[-1], + ) diff --git a/fme/fme/core/generics/train_stepper.py b/fme/fme/core/generics/train_stepper.py new file mode 100644 index 0000000..e9db6c9 --- /dev/null +++ b/fme/fme/core/generics/train_stepper.py @@ -0,0 +1,63 @@ +import abc +from typing import Any, Dict, Generic, Type, TypeVar + +from torch import nn + +from fme.core.generics.inference import PredictFunction +from fme.core.generics.optimization import OptimizationABC +from fme.core.typing_ import TensorDict + +TO = TypeVar("TO", bound="TrainOutputABC") # train output + + +class TrainOutputABC(abc.ABC): + @abc.abstractmethod + def get_metrics(self) -> TensorDict: + pass + + +PS = TypeVar("PS") # prognostic state +BD = TypeVar("BD") # batch data +FD = TypeVar("FD") # forcing data +SD = TypeVar("SD") # stepped data + + +class TrainStepperABC(abc.ABC, Generic[PS, BD, FD, SD, TO]): + SelfType = TypeVar("SelfType", bound="TrainStepperABC") + + @abc.abstractmethod + def train_on_batch( + self, + data: BD, + optimization: OptimizationABC, + compute_derived_variables: bool = False, + ) -> TO: + pass + + @property + @abc.abstractmethod + def modules(self) -> nn.ModuleList: + pass + + @abc.abstractmethod + def get_state(self) -> Dict[str, Any]: + pass + + @abc.abstractmethod + def load_state(self, state: Dict[str, Any]) -> None: + pass + + @classmethod + @abc.abstractmethod + def from_state(cls: Type[SelfType], state: Dict[str, Any]) -> SelfType: + pass + + @property + @abc.abstractmethod + def n_ic_timesteps(self) -> int: + pass + + @property + @abc.abstractmethod + def predict_paired(self) -> PredictFunction[PS, FD, SD]: + pass diff --git a/fme/fme/core/generics/trainer.py b/fme/fme/core/generics/trainer.py new file mode 100644 index 0000000..8e88a13 --- /dev/null +++ b/fme/fme/core/generics/trainer.py @@ -0,0 +1,538 @@ +# This module is derived from the train.py module in the following repository: +# https://github.com/NVlabs/FourCastNet. The corresponding license is +# provided below. + +# BSD 3-Clause License +# +# Copyright (c) 2022, FourCastNet authors +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. 
Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +# The code was authored by the following people: +# +# Jaideep Pathak - NVIDIA Corporation +# Shashank Subramanian - NERSC, Lawrence Berkeley National Laboratory +# Peter Harrington - NERSC, Lawrence Berkeley National Laboratory +# Sanjeev Raja - NERSC, Lawrence Berkeley National Laboratory +# Ashesh Chattopadhyay - Rice University +# Morteza Mardani - NVIDIA Corporation +# Thorsten Kurth - NVIDIA Corporation +# David Hall - NVIDIA Corporation +# Zongyi Li - California Institute of Technology, NVIDIA Corporation +# Kamyar Azizzadenesheli - Purdue University +# Pedram Hassanzadeh - Rice University +# Karthik Kashinath - NVIDIA Corporation +# Animashree Anandkumar - California Institute of Technology, NVIDIA Corporation + +import abc +import contextlib +import gc +import logging +import os +import time +import uuid +from typing import ( + Any, + Callable, + ClassVar, + Dict, + Generic, + List, + Optional, + Protocol, + TypeVar, +) + +import torch + +import fme +from fme.core.distributed import Distributed +from fme.core.ema import EMATracker +from fme.core.generics.aggregator import AggregatorABC, InferenceAggregatorABC +from fme.core.generics.data import GriddedDataABC, InferenceDataABC +from fme.core.generics.inference import run_inference +from fme.core.generics.train_stepper import TrainOutputABC, TrainStepperABC +from fme.core.optimization import NullOptimization, Optimization +from fme.core.timing import GlobalTimer +from fme.core.typing_ import Slice +from fme.core.wandb import WandB + + +class EndOfBatchCallback(Protocol): + def __call__(self) -> None: ... + + +class TrainConfigProtocol(Protocol): + __dataclass_fields__: ClassVar[Dict[str, Any]] + + @property + def experiment_dir(self) -> str: ... + + @property + def checkpoint_dir(self) -> str: ... + + @property + def max_epochs(self) -> int: ... + + @property + def save_checkpoint(self) -> bool: ... + + @property + def validate_using_ema(self) -> bool: ... + + @property + def log_train_every_n_batches(self) -> int: ... + + @property + def segment_epochs(self) -> Optional[int]: ... + + @property + def checkpoint_save_epochs(self) -> Optional[Slice]: ... + + @property + def ema_checkpoint_save_epochs(self) -> Optional[Slice]: ... + + def clean_wandb(self, experiment_dir: str) -> None: ... + + def get_inference_epochs(self) -> List[int]: ... 
+ + +PS = TypeVar("PS", contravariant=True) # prognostic state +TO = TypeVar("TO", bound="TrainOutputABC") # train output +BD = TypeVar("BD") # batch data for training +FD = TypeVar("FD") # forcing data for inference +SD = TypeVar("SD") # stepped data from inference + + +class AggregatorBuilderABC(abc.ABC, Generic[PS, TO, SD]): + @abc.abstractmethod + def get_train_aggregator(self) -> AggregatorABC[TO]: + pass + + @abc.abstractmethod + def get_validation_aggregator(self) -> AggregatorABC[TO]: + pass + + @abc.abstractmethod + def get_inference_aggregator(self) -> InferenceAggregatorABC[PS, SD]: + pass + + +class CheckpointPaths: + def __init__(self, checkpoint_dir: str): + self.checkpoint_dir = checkpoint_dir + + @property + def latest_checkpoint_path(self) -> str: + return os.path.join(self.checkpoint_dir, "ckpt.tar") + + @property + def best_checkpoint_path(self) -> str: + return os.path.join(self.checkpoint_dir, "best_ckpt.tar") + + @property + def best_inference_checkpoint_path(self) -> str: + return os.path.join(self.checkpoint_dir, "best_inference_ckpt.tar") + + @property + def ema_checkpoint_path(self) -> str: + return os.path.join(self.checkpoint_dir, "ema_ckpt.tar") + + def epoch_checkpoint_path(self, epoch: int) -> str: + return os.path.join(self.checkpoint_dir, f"ckpt_{epoch:04d}.tar") + + def ema_epoch_checkpoint_path(self, epoch: int) -> str: + return os.path.join(self.checkpoint_dir, f"ema_ckpt_{epoch:04d}.tar") + + +class Trainer: + def __init__( + self, + train_data: GriddedDataABC[BD], + validation_data: GriddedDataABC[BD], + inference_data: InferenceDataABC[PS, FD], + stepper: TrainStepperABC[PS, BD, FD, SD, TO], + build_optimization: Callable[[torch.nn.ModuleList], Optimization], + build_ema: Callable[[torch.nn.ModuleList], EMATracker], + config: TrainConfigProtocol, + aggregator_builder: AggregatorBuilderABC[PS, TO, SD], + end_of_batch_callback: EndOfBatchCallback = lambda: None, + do_gc_collect: bool = True, + ): + logging.info(f"Current device is {fme.get_device()}") + self.dist = Distributed.get_instance() + if self.dist.is_root(): + if not os.path.isdir(config.experiment_dir): + os.makedirs(config.experiment_dir) + if not os.path.isdir(config.checkpoint_dir): + os.makedirs(config.checkpoint_dir) + self.config = config + self.paths = CheckpointPaths(config.checkpoint_dir) + + logging.info("rank %d, begin data loader init" % self.dist.rank) + self.train_data = train_data + self.valid_data = validation_data + logging.info("rank %d, data loader initialized" % self.dist.rank) + for gridded_data, name in zip( + (self.train_data, self.valid_data), ("train", "valid") + ): + gridded_data.log_info(name) + + self.num_batches_seen = 0 + self._start_epoch = 0 + self._model_epoch = self._start_epoch + self.num_batches_seen = 0 + self._best_validation_loss = torch.inf + self._best_inference_error = torch.inf + + self.stepper = stepper + self.optimization = build_optimization(stepper.modules) + self._end_of_batch_ops = end_of_batch_callback + self._no_optimization = NullOptimization() + self._aggregator_builder = aggregator_builder + + resuming = os.path.isfile(self.paths.latest_checkpoint_path) + if resuming: + logging.info(f"Resuming training from {self.paths.latest_checkpoint_path}") + self.restore_checkpoint( + self.paths.latest_checkpoint_path, self.paths.ema_checkpoint_path + ) + + wandb = WandB.get_instance() + wandb.watch(self.stepper.modules) + + logging.info( + ( + "Number of trainable model parameters: " + f"{count_parameters(self.stepper.modules)}" + ) + ) + + 
self._inference_data = inference_data + self._ema = build_ema(stepper.modules) + self._do_gc_collect = do_gc_collect + + def switch_off_grad(self, model: torch.nn.Module): + for param in model.parameters(): + param.requires_grad = False + + def train(self): + logging.info("Starting Training Loop...") + + self._model_epoch = self._start_epoch + inference_epochs = self.config.get_inference_epochs() + if self.config.segment_epochs is None: + segment_max_epochs = self.config.max_epochs + else: + segment_max_epochs = min( + self._start_epoch + self.config.segment_epochs, self.config.max_epochs + ) + # "epoch" describes the loop, self._model_epoch describes model weights + # needed so we can describe the loop even after weights are updated + for epoch in range(self._start_epoch, segment_max_epochs): + if self._do_gc_collect: + # garbage collect to avoid CUDA error in some contexts + # https://github.com/pytorch/pytorch/issues/67978#issuecomment-1661986812 # noqa: E501 + gc.collect() + logging.info(f"Epoch: {epoch+1}") + self.train_data.set_epoch(epoch) + + start_time = time.time() + logging.info(f"Starting training step on epoch {epoch + 1}") + train_logs = self.train_one_epoch() + train_end = time.time() + logging.info(f"Starting validation step on epoch {epoch + 1}") + valid_logs = self.validate_one_epoch() + valid_end = time.time() + if epoch in inference_epochs: + logging.info(f"Starting inference step on epoch {epoch + 1}") + inference_logs = self.inference_one_epoch() + inference_end: Optional[float] = time.time() + else: + inference_logs = {} + inference_end = None + + train_loss = train_logs["train/mean/loss"] + valid_loss = valid_logs["val/mean/loss"] + inference_error = inference_logs.get( + "inference/time_mean_norm/rmse/channel_mean", None + ) + # need to get the learning rate before stepping the scheduler + lr = self.optimization.learning_rate + self.optimization.step_scheduler(valid_loss) + + if self.dist.is_root(): + if self.config.save_checkpoint: + logging.info(f"Saving checkpoints for epoch {epoch + 1}") + self.save_all_checkpoints(valid_loss, inference_error) + + time_elapsed = time.time() - start_time + logging.info(f"Time taken for epoch {epoch + 1} is {time_elapsed} sec") + logging.info(f"Train loss: {train_loss}. Valid loss: {valid_loss}") + if inference_error is not None: + logging.info(f"Inference error: {inference_error}") + + logging.info("Logging to wandb") + all_logs = { + **train_logs, + **valid_logs, + **inference_logs, + **{ + "lr": lr, + "epoch": epoch, + "epoch_train_seconds": train_end - start_time, + "epoch_validation_seconds": valid_end - train_end, + "epoch_total_seconds": time_elapsed, + }, + } + if inference_end is not None: + all_logs["epoch_inference_seconds"] = inference_end - valid_end + wandb = WandB.get_instance() + wandb.log(all_logs, step=self.num_batches_seen) + if segment_max_epochs == self.config.max_epochs: + self.config.clean_wandb(experiment_dir=self.config.experiment_dir) + + def train_one_epoch(self): + """Train for one epoch and return logs from TrainAggregator.""" + wandb = WandB.get_instance() + aggregator = self._aggregator_builder.get_train_aggregator() + n_samples_seen_since_logging = 0 + if self.num_batches_seen == 0: + # Before training, log the loss on the first batch. 
+ with torch.no_grad(), GlobalTimer(): + batch = next(iter(self.train_data.loader)) + stepped = self.stepper.train_on_batch( + batch, + optimization=self._no_optimization, + ) + + if self.config.log_train_every_n_batches > 0: + with torch.no_grad(): + metrics = { + f"batch_{name}": self.dist.reduce_mean(metric) + for name, metric in sorted(stepped.get_metrics().items()) + } + wandb.log(metrics, step=self.num_batches_seen) + current_time = time.time() + for batch in self.train_data.loader: + with GlobalTimer(): + stepped = self.stepper.train_on_batch(batch, self.optimization) + aggregator.record_batch(stepped) + self._end_of_batch_ops() + self._ema(model=self.stepper.modules) + self.num_batches_seen += 1 + n_samples_seen_since_logging += self.train_data.batch_size + if ( + self.config.log_train_every_n_batches > 0 + and self.num_batches_seen % self.config.log_train_every_n_batches == 0 + ): + with torch.no_grad(): + metrics = { + f"batch_{name}": self.dist.reduce_mean(metric) + for name, metric in sorted(stepped.get_metrics().items()) + } + duration = time.time() - current_time + current_time = time.time() + samples_per_second = n_samples_seen_since_logging / duration + metrics["training_samples_per_second_on_rank_0"] = samples_per_second + wandb.log(metrics, step=self.num_batches_seen) + n_samples_seen_since_logging = 0 + self._model_epoch += 1 + + return aggregator.get_logs(label="train") + + @contextlib.contextmanager + def _validation_context(self): + """ + The context for running validation. + + In this context, the stepper uses the EMA model if + `self.config.validate_using_ema` is True. + """ + if self.config.validate_using_ema: + with self._ema_context(): + yield + else: + yield + + @contextlib.contextmanager + def _ema_context(self): + """ + A context where the stepper uses the EMA model. 
+ """ + self._ema.store(parameters=self.stepper.modules.parameters()) + self._ema.copy_to(model=self.stepper.modules) + try: + yield + finally: + self._ema.restore(parameters=self.stepper.modules.parameters()) + + def validate_one_epoch(self): + aggregator = self._aggregator_builder.get_validation_aggregator() + with torch.no_grad(), self._validation_context(), GlobalTimer(): + for batch in self.valid_data.loader: + stepped = self.stepper.train_on_batch( + batch, + optimization=NullOptimization(), + compute_derived_variables=True, + ) + aggregator.record_batch( + batch=stepped, + ) + return aggregator.get_logs(label="val") + + def inference_one_epoch(self): + aggregator = self._aggregator_builder.get_inference_aggregator() + with torch.no_grad(), self._validation_context(), GlobalTimer(): + run_inference( + predict=self.stepper.predict_paired, + data=self._inference_data, + aggregator=aggregator, + ) + logs = aggregator.get_summary_logs() + return {f"inference/{k}": v for k, v in logs.items()} + + def save_checkpoint(self, checkpoint_path): + # save to a temporary file in case we get pre-empted during save + temporary_location = os.path.join( + os.path.dirname(checkpoint_path), f".{uuid.uuid4()}.tmp" + ) + try: + torch.save( + { + "num_batches_seen": self.num_batches_seen, + "epoch": self._model_epoch, + "best_validation_loss": self._best_validation_loss, + "best_inference_error": self._best_inference_error, + "stepper": self.stepper.get_state(), + "optimization": self.optimization.get_state(), + "ema": self._ema.get_state(), + }, + temporary_location, + ) + os.replace(temporary_location, checkpoint_path) + finally: + if os.path.exists(temporary_location): + os.remove(temporary_location) + + def restore_checkpoint(self, checkpoint_path, ema_checkpoint_path): + _restore_checkpoint(self, checkpoint_path, ema_checkpoint_path) + + def _epoch_checkpoint_enabled(self, epoch: int) -> bool: + return epoch_checkpoint_enabled( + epoch, self.config.max_epochs, self.config.checkpoint_save_epochs + ) + + def _ema_epoch_checkpoint_enabled(self, epoch: int) -> bool: + return epoch_checkpoint_enabled( + epoch, self.config.max_epochs, self.config.ema_checkpoint_save_epochs + ) + + def save_all_checkpoints(self, valid_loss: float, inference_error: Optional[float]): + if self.config.validate_using_ema: + best_checkpoint_context = self._ema_context + else: + best_checkpoint_context = contextlib.nullcontext # type: ignore + with best_checkpoint_context(): + save_best_checkpoint = False + if valid_loss <= self._best_validation_loss: + logging.info( + "Saving lowest validation loss checkpoint to " + f"{self.paths.best_checkpoint_path}" + ) + self._best_validation_loss = valid_loss + save_best_checkpoint = True # wait until inference error is updated + if inference_error is not None and ( + inference_error <= self._best_inference_error + ): + logging.info( + f"Epoch inference error ({inference_error}) is lower than " + f"previous best inference error ({self._best_inference_error})." 
+ ) + logging.info( + "Saving lowest inference error checkpoint to " + f"{self.paths.best_inference_checkpoint_path}" + ) + self._best_inference_error = inference_error + self.save_checkpoint(self.paths.best_inference_checkpoint_path) + if save_best_checkpoint: + self.save_checkpoint(self.paths.best_checkpoint_path) + + logging.info(f"Saving latest checkpoint to {self.paths.latest_checkpoint_path}") + self.save_checkpoint(self.paths.latest_checkpoint_path) + with self._ema_context(): + logging.info( + f"Saving latest EMA checkpoint to {self.paths.ema_checkpoint_path}" + ) + self.save_checkpoint(self.paths.ema_checkpoint_path) + if self._epoch_checkpoint_enabled(self._model_epoch): + epoch_checkpoint_path = self.paths.epoch_checkpoint_path(self._model_epoch) + logging.info(f"Saving epoch checkpoint to {epoch_checkpoint_path}") + self.save_checkpoint(epoch_checkpoint_path) + if self._ema_epoch_checkpoint_enabled(self._model_epoch): + ema_epoch_checkpoint_path = self.paths.ema_epoch_checkpoint_path( + self._model_epoch + ) + logging.info(f"Saving EMA epoch checkpoint to {ema_epoch_checkpoint_path}") + with self._ema_context(): + self.save_checkpoint(ema_epoch_checkpoint_path) + + +def _restore_checkpoint(trainer: Trainer, checkpoint_path, ema_checkpoint_path): + # separated into a function only to make it easier to mock + checkpoint = torch.load(checkpoint_path, map_location=fme.get_device()) + # restore checkpoint is used for finetuning as well as resuming. + # If finetuning (i.e., not resuming), restore checkpoint + # does not load optimizer state, instead uses config specified lr. + trainer.stepper.load_state(checkpoint["stepper"]) + trainer.optimization.load_state(checkpoint["optimization"]) + trainer.num_batches_seen = checkpoint["num_batches_seen"] + trainer._start_epoch = checkpoint["epoch"] + trainer._best_validation_loss = checkpoint["best_validation_loss"] + trainer._best_inference_error = checkpoint["best_inference_error"] + ema_checkpoint = torch.load(ema_checkpoint_path, map_location=fme.get_device()) + ema_stepper: TrainStepperABC = type(trainer.stepper).from_state( + ema_checkpoint["stepper"] + ) + trainer._ema = EMATracker.from_state(checkpoint["ema"], ema_stepper.modules) + + +def count_parameters(modules: torch.nn.ModuleList) -> int: + parameters = 0 + for module in modules: + for parameter in module.parameters(): + if parameter.requires_grad: + parameters += parameter.numel() + return parameters + + +def epoch_checkpoint_enabled( + epoch: int, max_epochs: int, save_epochs: Optional[Slice] +) -> bool: + if save_epochs is None: + return False + return epoch in range(max_epochs)[save_epochs.slice] diff --git a/fme/fme/core/generics/writer.py b/fme/fme/core/generics/writer.py new file mode 100644 index 0000000..1692503 --- /dev/null +++ b/fme/fme/core/generics/writer.py @@ -0,0 +1,46 @@ +import abc +from typing import Any, Generic, TypeVar + +PS = TypeVar("PS", contravariant=True) # prognostic state +SD = TypeVar("SD", contravariant=True) # stepped data + + +class WriterABC(abc.ABC, Generic[PS, SD]): + @abc.abstractmethod + def write(self, data: PS, filename: str): + """Eagerly write data to a file at filename.""" + ... + + @abc.abstractmethod + def append_batch( + self, + batch: SD, + ): + """ + Append a batch of data to the output file(s). + + Args: + batch: Data to be written. + """ + ... + + +class NullDataWriter(WriterABC[Any, Any]): + """ + Null pattern for DataWriter, which does nothing. 
+ """ + + def __init__(self): + pass + + def append_batch( + self, + batch: Any, + ): + pass + + def flush(self): + pass + + def write(self, data: Any, filename: str): + pass diff --git a/fme/fme/core/gridded_ops.py b/fme/fme/core/gridded_ops.py new file mode 100644 index 0000000..d2ef53f --- /dev/null +++ b/fme/fme/core/gridded_ops.py @@ -0,0 +1,144 @@ +import abc +from typing import Any, Dict, List, Type, TypeVar + +import torch + +from fme.core import metrics +from fme.core.device import get_device + + +class GriddedOperations(abc.ABC): + @abc.abstractmethod + def area_weighted_mean( + self, data: torch.Tensor, keepdim: bool = False + ) -> torch.Tensor: ... + + def area_weighted_mean_bias( + self, truth: torch.Tensor, predicted: torch.Tensor + ) -> torch.Tensor: + return self.area_weighted_mean(predicted - truth) + + def area_weighted_rmse( + self, truth: torch.Tensor, predicted: torch.Tensor + ) -> torch.Tensor: + return torch.sqrt(self.area_weighted_mean((predicted - truth) ** 2)) + + def area_weighted_std(self, data: torch.Tensor, keepdim: bool = False): + return self.area_weighted_mean( + (data - self.area_weighted_mean(data, keepdim=True)) ** 2, + keepdim=keepdim, + ).sqrt() + + @abc.abstractmethod + def area_weighted_gradient_magnitude_percent_diff( + self, truth: torch.Tensor, predicted: torch.Tensor + ): ... + + def to_state(self) -> Dict[str, Any]: + return { + "type": self.__class__.__name__, + "state": self.get_initialization_kwargs(), + } + + @abc.abstractmethod + def get_initialization_kwargs(self) -> Dict[str, Any]: + """ + Get the keyword arguments needed to initialize the instance. + """ + ... + + @classmethod + def from_state(cls, state: Dict[str, Any]) -> "GriddedOperations": + """ + Given a dictionary with a "type" key and a "state" key, return + the GriddedOperations it describes. + + The "type" key should be the name of a subclass of GriddedOperations, + and the "state" key should be a dictionary specific to + that subclass. + + Args: + state: A dictionary with a "type" key and a "state" key. + + Returns: + An instance of the subclass. + """ + if cls is not GriddedOperations: + raise RuntimeError( + "This method should be called on GriddedOperations, " + "not on its subclasses." + ) + subclasses = get_all_subclasses(cls) + for subclass in subclasses: + if subclass.__name__ == state["type"]: + return subclass(**state["state"]) + raise ValueError( + f"Unknown subclass type: {state['type']}, " + f"available: {[s.__name__ for s in subclasses]}" + ) + + +T = TypeVar("T") + + +def get_all_subclasses(cls: Type[T]) -> List[Type[T]]: + """ + Gets all subclasses of a given class, including their subclasses etc. 
+ """ + all_subclasses = [] + for subclass in cls.__subclasses__(): + all_subclasses.append(subclass) + all_subclasses.extend(get_all_subclasses(subclass)) + return all_subclasses + + +class LatLonOperations(GriddedOperations): + HORIZONTAL_DIMS = (-2, -1) + + def __init__(self, area_weights: torch.Tensor): + self._device_area = area_weights.to(get_device()) + self._cpu_area = area_weights.to("cpu") + + def area_weighted_mean( + self, data: torch.Tensor, keepdim: bool = False + ) -> torch.Tensor: + if data.device.type == "cpu": + area_weights = self._cpu_area + else: + area_weights = self._device_area + return metrics.weighted_mean( + data, area_weights, dim=self.HORIZONTAL_DIMS, keepdim=keepdim + ) + + def area_weighted_gradient_magnitude_percent_diff( + self, truth: torch.Tensor, predicted: torch.Tensor + ): + if predicted.device.type == "cpu": + area_weights = self._cpu_area + else: + area_weights = self._device_area + return metrics.gradient_magnitude_percent_diff( + truth, predicted, weights=area_weights, dim=self.HORIZONTAL_DIMS + ) + + def get_initialization_kwargs(self) -> Dict[str, Any]: + return {"area_weights": self._cpu_area} + + +class HEALPixOperations(GriddedOperations): + HORIZONTAL_DIMS = (-3, -2, -1) + + def area_weighted_mean( + self, data: torch.Tensor, keepdim: bool = False + ) -> torch.Tensor: + return data.mean(dim=self.HORIZONTAL_DIMS, keepdim=keepdim) + + def area_weighted_gradient_magnitude_percent_diff( + self, truth: torch.Tensor, predicted: torch.Tensor + ): + return metrics.gradient_magnitude_percent_diff( + truth, predicted, weights=None, dim=self.HORIZONTAL_DIMS + ) + + def get_initialization_kwargs(self) -> Dict[str, Any]: + return {} diff --git a/fme/fme/core/histogram.py b/fme/fme/core/histogram.py index 8cf4b45..91d9c97 100644 --- a/fme/fme/core/histogram.py +++ b/fme/fme/core/histogram.py @@ -1,7 +1,7 @@ import collections import logging from collections import namedtuple -from typing import Dict, List, Literal, Mapping, Optional, Tuple +from typing import Dict, List, Literal, Mapping, Optional, Tuple, Union import matplotlib.figure import matplotlib.pyplot as plt @@ -128,6 +128,8 @@ def _double_size_left(self): Double the sizes of bins, extending the histogram to the left (further negative). """ + if self.bin_edges is None: + raise RuntimeError("Cannot double size of bins without bin edges") current_range = self.bin_edges[-1] - self.bin_edges[0] new_range = 2 * current_range @@ -147,6 +149,8 @@ def _double_size_right(self): Double the sizes of bins, extending the histogram to the right (further positive). """ + if self.bin_edges is None: + raise RuntimeError("Cannot double size of bins without bin edges") current_range = self.bin_edges[-1] - self.bin_edges[0] new_range = 2 * current_range @@ -167,7 +171,8 @@ def _double_size_right(self): class ComparedDynamicHistograms: """Wrapper of DynamicHistogram for multiple histograms, two histograms per - variable plotted on the same axis.""" + variable plotted on the same axis. 
+ """ def __init__(self, n_bins: int, percentiles: Optional[List[float]] = None) -> None: self.n_bins = n_bins @@ -205,9 +210,9 @@ def _get_histograms( ) -> Dict[str, Dict[Literal["target", "prediction"], _Histogram]]: if self.target_histograms is None or self.prediction_histograms is None: raise ValueError("No data has been added to the histogram") - return_dict: Dict[ - str, Dict[Literal["target", "prediction"], _Histogram] - ] = collections.defaultdict(dict) + return_dict: Dict[str, Dict[Literal["target", "prediction"], _Histogram]] = ( + collections.defaultdict(dict) + ) for k in self.target_histograms: counts, bin_edges = trim_zero_bins( self.target_histograms[k].counts.squeeze(self._time_dim), @@ -250,7 +255,7 @@ def _plot_histogram( return fig def get_wandb(self) -> Dict[str, float]: - return_dict: Dict[str, float] = {} + return_dict: Dict[str, Union[matplotlib.figure.Figure, float]] = {} for field_name, histograms in self._get_histograms().items(): target = histograms.get("target") @@ -309,9 +314,9 @@ def get_dataset(self) -> xr.Dataset: np.zeros_like(target_dataset[missing_prediction_name]), dims=("bin",), ) - prediction_dataset[ - f"{missing_prediction_name}_bin_edges" - ] = target_dataset[f"{missing_prediction_name}_bin_edges"] + prediction_dataset[f"{missing_prediction_name}_bin_edges"] = ( + target_dataset[f"{missing_prediction_name}_bin_edges"] + ) ds = xr.concat([target_dataset, prediction_dataset], dim="source") ds["source"] = ["target", "prediction"] return ds diff --git a/fme/fme/core/logging_utils.py b/fme/fme/core/logging_utils.py index 56bda8d..c33e1e2 100644 --- a/fme/fme/core/logging_utils.py +++ b/fme/fme/core/logging_utils.py @@ -25,7 +25,7 @@ class LoggingConfig: """ Configuration for logging. - Attributes: + Parameters: project: name of the project in Weights & Biases entity: name of the entity in Weights & Biases log_to_screen: whether to log to the screen diff --git a/fme/fme/core/loss.py b/fme/fme/core/loss.py index 3af0fe3..8497f0f 100644 --- a/fme/fme/core/loss.py +++ b/fme/fme/core/loss.py @@ -1,15 +1,12 @@ import dataclasses -from typing import Any, Dict, List, Literal, Mapping, Optional +from typing import Any, Callable, Dict, List, Literal, Mapping, Optional import torch import torch.linalg -from fme.core.data_loading.data_typing import SigmaCoordinates from fme.core.device import get_device from fme.core.packer import Packer -from fme.core.typing_ import TensorDict, TensorMapping - -from .climate_data import ClimateData, compute_dry_air_absolute_differences +from fme.core.typing_ import TensorDict class NaNLoss(torch.nn.Module): @@ -37,27 +34,6 @@ def __call__( return self.loss(predict_tensors, target_tensors) -def get_dry_air_nonconservation( - data: TensorMapping, - area_weights: torch.Tensor, - sigma_coordinates: SigmaCoordinates, -): - """ - Computes the time-average one-step absolute difference in surface pressure due to - changes in globally integrated dry air. - - Args: - data: A mapping from variable name to tensor of shape - [sample, time, lat, lon], in physical units. specific_total_water in kg/kg - and surface_pressure in Pa must be present. - area_weights: The area of each grid cell as a [lat, lon] tensor, in m^2. - sigma_coordinates: The sigma coordinates of the model. 
- """ - return compute_dry_air_absolute_differences( - ClimateData(data), area=area_weights, sigma_coordinates=sigma_coordinates - ).mean() - - def _construct_weight_tensor( weights: Dict[str, float], out_names: List[str], @@ -76,7 +52,6 @@ def _construct_weight_tensor( n_dim: number of dimensions of the output tensor channel_dim: the channel dimension of the output tensor """ - missing_keys = set(weights.keys()) - set(out_names) if len(missing_keys) > 0: raise KeyError( @@ -120,12 +95,12 @@ def __call__(self, x, y): class AreaWeightedMSELoss(torch.nn.Module): - def __init__(self, area: torch.Tensor): + def __init__(self, area_weighted_mean: Callable[[torch.Tensor], torch.Tensor]): super(AreaWeightedMSELoss, self).__init__() - self._area_weights = area / area.mean() + self._area_weighted_mean = area_weighted_mean def __call__(self, x, y): - return torch.mean((x - y) ** 2 * self._area_weights) + return torch.mean(self._area_weighted_mean((x - y) ** 2)) class WeightedSum(torch.nn.Module): @@ -157,17 +132,21 @@ class GlobalMeanLoss(torch.nn.Module): A module which computes a loss on the global mean of each sample. """ - def __init__(self, area: torch.Tensor, loss: torch.nn.Module): + def __init__( + self, + area_weighted_mean: Callable[[torch.Tensor], torch.Tensor], + loss: torch.nn.Module, + ): """ Args: - area: A tensor of shape (n_lat, n_lon) containing the area of - each grid cell. + area_weighted_mean: Computes an area-weighted mean, removing the + horizontal dimensions. loss: A loss function which takes two tensors of shape (n_samples, n_timesteps, n_channels) and returns a scalar tensor. """ super().__init__() - self.global_mean = GlobalMean(area) + self.global_mean = GlobalMean(area_weighted_mean) self.loss = loss def __call__(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: @@ -177,21 +156,22 @@ def __call__(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: class GlobalMean(torch.nn.Module): - def __init__(self, area: torch.Tensor): + def __init__(self, area_weighted_mean: Callable[[torch.Tensor], torch.Tensor]): """ Args: - area: A tensor of shape (n_lat, n_lon) containing the area of - each grid cell. + area_weighted_mean: Computes an area-weighted mean, removing the + horizontal dimensions. """ super().__init__() - self.area_weights = area / area.sum() + self._area_weighted_mean = area_weighted_mean def forward(self, x: torch.Tensor) -> torch.Tensor: """ Args: - x: A tensor of shape (n_samples, n_timesteps, n_channels, n_lat, n_lon) + x: A tensor with spatial dimensions in shape (n_samples, n_timesteps, + n_channels, n_lat, n_lon). """ - return (x * self.area_weights[None, None, None, :, :]).sum(dim=(3, 4)) + return self._area_weighted_mean(x) class VariableWeightingLoss(torch.nn.Module): @@ -200,6 +180,7 @@ def __init__(self, weights: torch.Tensor, loss: torch.nn.Module): Args: weights: A tensor of shape (n_samples, n_channels, n_lat, n_lon) containing the weights to apply to each channel. + loss: A loss function which takes two tensors. """ super().__init__() self.loss = loss @@ -240,11 +221,18 @@ def __post_init__(self): if self.global_mean_type is not None and self.global_mean_type != "LpLoss": raise NotImplementedError(self.global_mean_type) - def build(self, area: torch.Tensor, reduction: Literal["mean", "none"]) -> Any: + def build( + self, + area_weighted_mean: Callable[[torch.Tensor], torch.Tensor], + reduction: Literal["mean", "none"], + ) -> Any: """ Args: - area: A tensor of shape (n_lat, n_lon) containing the area of - each grid cell. 
+ area_weighted_mean: Computes an area-weighted mean, removing the + horizontal dimensions. Only used if the loss function is + AreaWeightedMSE. + reduction: The reduction to apply to the loss, either "mean" or "none". + Only used if the loss function is L1, MSE, or LpLoss. """ if self.type == "LpLoss": main_loss = LpLoss(**self.kwargs) @@ -253,13 +241,14 @@ def build(self, area: torch.Tensor, reduction: Literal["mean", "none"]) -> Any: elif self.type == "MSE": main_loss = torch.nn.MSELoss(reduction=reduction) elif self.type == "AreaWeightedMSE": - main_loss = AreaWeightedMSELoss(area) + main_loss = AreaWeightedMSELoss(area_weighted_mean) elif self.type == "NaN": main_loss = NaNLoss() if self.global_mean_type is not None: global_mean_loss = GlobalMeanLoss( - area=area, loss=LpLoss(**self.global_mean_kwargs) + area_weighted_mean=area_weighted_mean, + loss=LpLoss(**self.global_mean_kwargs), ) final_loss = WeightedSum( modules=[main_loss, global_mean_loss], @@ -312,9 +301,12 @@ def __post_init__(self): ) def build( - self, area: torch.Tensor, out_names: List[str], channel_dim: int = -3 + self, + area_weighted_mean: Callable[[torch.Tensor], torch.Tensor], + out_names: List[str], + channel_dim: int = -3, ) -> Any: - loss = self.loss_config.build(area, reduction="mean") + loss = self.loss_config.build(area_weighted_mean, reduction="mean") weighted_loss = VariableWeightingLoss( weights=_construct_weight_tensor( self.weights, out_names, channel_dim=channel_dim diff --git a/fme/fme/core/masking.py b/fme/fme/core/masking.py new file mode 100644 index 0000000..d173512 --- /dev/null +++ b/fme/fme/core/masking.py @@ -0,0 +1,129 @@ +import dataclasses +from typing import Optional + +import torch + +from fme.core.stacker import Stacker, unstack +from fme.core.typing_ import TensorDict, TensorMapping + + +def replace_on_mask( + original: torch.Tensor, + replacement: torch.Tensor, + mask: torch.Tensor, + mask_value: int, +): + """Replace original with replacement in masked regions. + + Args: + original: The original data tensor. + replacement: The replacement data tensor. + mask: The mask tensor. + mask_value: The value of the mask variable in the region to be replaced. + """ + rounded_mask = torch.round(mask).to(int) + return torch.where( + condition=rounded_mask == mask_value, + input=replacement, + other=original, + ) + + +@dataclasses.dataclass +class MaskingConfig: + """ + Configuration for applying masking to the generated output. + + Parameters: + mask_name: The standard name of the mask. May be a prefix (for a 3D masking + variable) or a full name (for a 2D masking variable). + mask_value: Value of the mask variable in masked regions. Either 0 or 1. + fill_value: The constant fill value to use outside of masked regions. + surface_mask_name: (optional) The full name of the surface mask. Only required + when mask_name is a prefix and separate 2D surface masking is desired. 
+ """ + + mask_name: str + mask_value: int + fill_value: float = 0.0 + surface_mask_name: Optional[str] = None + + def __post_init__(self): + if self.mask_value not in [0, 1]: + raise ValueError( + "mask_value must be either 0 or 1, but got " f"{self.mask_value}" + ) + + def build(self): + return Masking( + mask_name=self.mask_name, + mask_value=self.mask_value, + fill_value=self.fill_value, + surface_mask_name=self.surface_mask_name, + ) + + +class Masking: + """Replace masked regions with a fill value.""" + + def __init__( + self, + mask_name: str, + mask_value: int, + fill_value: float, + surface_mask_name: Optional[str] = None, + ): + self.mask_name = mask_name + self.mask_value = mask_value + self.fill_value = fill_value + self.surface_mask_name = surface_mask_name + mask_map = {self.mask_name: [self.mask_name]} + if self.surface_mask_name is not None: + mask_map[self.surface_mask_name] = [self.surface_mask_name] + + self.mask_stacker = Stacker(mask_map) + + def __call__( + self, + stacker: Stacker, + data: TensorMapping, + mask_data: TensorMapping, + ) -> TensorDict: + """ + Apply masking to the data for standard names recognized by a stacker. + + Args: + stacker: A Stacker for variables to mask in data. + data: The data to mask. + mask_data: The mask data. + + """ + mask = self.mask_stacker(self.mask_name, mask_data) + if self.surface_mask_name is not None: + surface_mask = self.mask_stacker(self.surface_mask_name, mask_data) + else: + surface_mask = None + data_: TensorDict = {**data} + for name in stacker.standard_names: + stacked = stacker(name, data) + if stacked.size(-1) > 1: # 3D masking + mask_ = mask + elif surface_mask is not None: # 2D masking with surface mask + mask_ = surface_mask + elif mask.size(-1) == 1: # 2D masking with mask + mask_ = mask + else: + raise RuntimeError( + "Masking surface_mask_name is None but the input Stacker " + f"includes the 2D standard name {name}." + ) + fill = torch.full_like(stacked, self.fill_value) + masked = replace_on_mask( + original=stacked, + replacement=fill, + mask=mask_, + mask_value=self.mask_value, + ) + level_names = stacker.get_all_level_names(name, data) + data_.update(unstack(masked, level_names, dim=-1)) + return data_ diff --git a/fme/fme/core/metrics.py b/fme/fme/core/metrics.py index 7194ecf..b864c4b 100644 --- a/fme/fme/core/metrics.py +++ b/fme/fme/core/metrics.py @@ -6,6 +6,7 @@ from typing_extensions import TypeAlias from fme.core.constants import GRAVITY +from fme.core.coordinates import HybridSigmaPressureCoordinate Dimension: TypeAlias = Union[int, Iterable[int]] Array: TypeAlias = Union[np.ndarray, torch.Tensor] @@ -114,12 +115,11 @@ def root_mean_squared_error( dim: Dimension = (), ) -> torch.Tensor: """ - Computes the weighted global RMSE over all variables. Namely, for each variable: + Compute a weighted root mean square error between truth and predicted. - sqrt((weights * ((xhat - x) ** 2)).mean(dims)) + Namely: - If you want to compute the RMSE over the time dimension, then pass in - `truth.mean(time_dim)` and `predicted.mean(time_dim)` and specify `dims=space_dims`. + sqrt((weights * ((xhat - x) ** 2)).mean(dims)) Args: truth: torch.Tensor whose last dimensions are to be weighted @@ -128,7 +128,7 @@ def root_mean_squared_error( dim: Dimensions to average over. Returns: - a tensor of shape (variable,) of weighted RMSEs. + A tensor of weighted RMSEs. 
""" assert ( truth.shape == predicted.shape @@ -157,7 +157,8 @@ def gradient_magnitude_percent_diff( dim: Dimension = (), ) -> torch.Tensor: """Compute the percent difference of the weighted mean gradient magnitude across - the specified dimensions.""" + the specified dimensions. + """ truth_grad_mag = weighted_mean_gradient_magnitude(truth, weights, dim) predicted_grad_mag = weighted_mean_gradient_magnitude(predicted, weights, dim) return 100 * (predicted_grad_mag - truth_grad_mag) / truth_grad_mag @@ -219,74 +220,23 @@ def time_and_global_mean_bias( return result -def vertical_integral( - integrand: torch.Tensor, - surface_pressure: torch.Tensor, - sigma_grid_offsets_ak: torch.Tensor, - sigma_grid_offsets_bk: torch.Tensor, -) -> torch.Tensor: - """Computes a vertical integral, namely: - - (1 / g) * ∫ x dp - - where - - g = acceleration due to gravity - - x = integrad - - p = pressure level - - Args: - integrand (lat, lon, vertical_level), (kg/kg) - surface_pressure: (lat, lon), (Pa) - sigma_grid_offsets_ak: Sorted sigma grid offsets ak, (vertical_level + 1,) - sigma_grid_offsets_bk: Sorted sigma grid offsets bk, (vertical_level + 1,) - - Returns: - Vertical integral of the integrand (lat, lon). - """ - ak, bk = sigma_grid_offsets_ak, sigma_grid_offsets_bk - pressure_thickness = ((ak + (surface_pressure.unsqueeze(-1) * bk))).diff( - dim=-1 - ) # Pa - integral = torch.sum(pressure_thickness * integrand, axis=-1) # type: ignore - return 1 / GRAVITY * integral - - def surface_pressure_due_to_dry_air( specific_total_water: torch.Tensor, surface_pressure: torch.Tensor, - sigma_grid_offsets_ak: torch.Tensor, - sigma_grid_offsets_bk: torch.Tensor, + vertical_coordinate: HybridSigmaPressureCoordinate, ) -> torch.Tensor: """Computes the dry air (Pa). Args: - specific_total_water (lat, lon, vertical_level), (kg/kg) - surface_pressure: (lat, lon), (Pa) - sigma_grid_offsets_ak: Sorted sigma grid offsets ak, (vertical_level + 1,) - sigma_grid_offsets_bk: Sorted sigma grid offsets bk, (vertical_level + 1,) + specific_total_water: last dimension is vertical level (kg/kg) + surface_pressure: the surface preessure in Pa. + vertical_coordinate: the vertical coordinate for computing vertical integral. Returns: - Vertically integrated dry air (lat, lon) (Pa) + The surface pressure due to dry air mass only. (Pa) """ - - num_levels = len(sigma_grid_offsets_ak) - 1 - - if ( - num_levels != len(sigma_grid_offsets_bk) - 1 - or num_levels != specific_total_water.shape[-1] - ): - raise ValueError( - ( - "Number of vertical levels in ak, bk, and specific_total_water must" - "be the same." - ) - ) - - total_water_path = vertical_integral( - specific_total_water, - surface_pressure, - sigma_grid_offsets_ak, - sigma_grid_offsets_bk, + total_water_path = vertical_coordinate.vertical_integral( + specific_total_water, surface_pressure ) dry_air = surface_pressure - GRAVITY * total_water_path return dry_air diff --git a/fme/fme/core/normalizer.py b/fme/fme/core/normalizer.py index 6bc3855..49cebc8 100644 --- a/fme/fme/core/normalizer.py +++ b/fme/fme/core/normalizer.py @@ -6,8 +6,8 @@ import torch import torch.jit -from fme.core.device import get_device -from fme.core.typing_ import TensorDict +from fme.core.device import move_tensordict_to_device +from fme.core.typing_ import TensorDict, TensorMapping @dataclasses.dataclass @@ -18,17 +18,15 @@ class NormalizationConfig: Either global_means_path and global_stds_path or explicit means and stds must be provided. 
- Attributes: + Parameters: global_means_path: Path to a netCDF file containing global means. global_stds_path: Path to a netCDF file containing global stds. - exclude_names: Names to exclude from normalization. means: Mapping from variable names to means. stds: Mapping from variable names to stds. """ global_means_path: Optional[str] = None global_stds_path: Optional[str] = None - exclude_names: Optional[List[str]] = None means: Mapping[str, float] = dataclasses.field(default_factory=dict) stds: Mapping[str, float] = dataclasses.field(default_factory=dict) @@ -49,8 +47,6 @@ def __post_init__(self): ) def build(self, names: List[str]): - if self.exclude_names is not None: - names = list(set(names) - set(self.exclude_names)) using_path = ( self.global_means_path is not None and self.global_stds_path is not None ) @@ -66,24 +62,6 @@ def build(self, names: List[str]): return StandardNormalizer(means=means, stds=stds) -@dataclasses.dataclass -class FromStateNormalizer: - """ - An alternative to NormalizationConfig which provides a normalizer - initialized from a serializable state. This is not a public configuration - class, but instead allows for loading trained models that have been - serialized to disk, using the pre-existing normalization state. - - Attributes: - state: State dict of a normalizer. - """ - - state: Dict[str, Dict[str, float]] - - def build(self, names: List[str]): - return StandardNormalizer.from_state(self.state) - - class StandardNormalizer: """ Responsible for normalizing tensors. @@ -94,14 +72,17 @@ def __init__( means: TensorDict, stds: TensorDict, ): - self.means = means - self.stds = stds + self.means = move_tensordict_to_device(means) + self.stds = move_tensordict_to_device(stds) + self._names = set(means).intersection(stds) - def normalize(self, tensors: TensorDict) -> TensorDict: - return _normalize(tensors, means=self.means, stds=self.stds) + def normalize(self, tensors: TensorMapping) -> TensorDict: + filtered_tensors = {k: v for k, v in tensors.items() if k in self._names} + return _normalize(filtered_tensors, means=self.means, stds=self.stds) - def denormalize(self, tensors: TensorDict) -> TensorDict: - return _denormalize(tensors, means=self.means, stds=self.stds) + def denormalize(self, tensors: TensorMapping) -> TensorDict: + filtered_tensors = {k: v for k, v in tensors.items() if k in self._names} + return _denormalize(filtered_tensors, means=self.means, stds=self.stds) def get_state(self): """ @@ -113,19 +94,15 @@ def get_state(self): } @classmethod - def from_state(self, state) -> "StandardNormalizer": + def from_state(cls, state) -> "StandardNormalizer": """ Loads state from a serializable data structure. 
""" means = { - k: torch.tensor(v, device=get_device(), dtype=torch.float) - for k, v in state["means"].items() - } - stds = { - k: torch.tensor(v, device=get_device(), dtype=torch.float) - for k, v in state["stds"].items() + k: torch.tensor(v, dtype=torch.float) for k, v in state["means"].items() } - return StandardNormalizer(means=means, stds=stds) + stds = {k: torch.tensor(v, dtype=torch.float) for k, v in state["stds"].items()} + return cls(means=means, stds=stds) @torch.jit.script @@ -134,10 +111,7 @@ def _normalize( means: TensorDict, stds: TensorDict, ) -> TensorDict: - return { - k: (t - means[k]) / stds[k] if k in means.keys() else t - for k, t in tensors.items() - } + return {k: (t - means[k]) / stds[k] for k, t in tensors.items()} @torch.jit.script @@ -146,10 +120,7 @@ def _denormalize( means: TensorDict, stds: TensorDict, ) -> TensorDict: - return { - k: t * stds[k] + means[k] if k in means.keys() else t - for k, t in tensors.items() - } + return {k: t * stds[k] + means[k] for k, t in tensors.items()} def get_normalizer( diff --git a/fme/fme/core/ocean.py b/fme/fme/core/ocean.py index 21cdf39..67b697c 100644 --- a/fme/fme/core/ocean.py +++ b/fme/fme/core/ocean.py @@ -16,7 +16,7 @@ class SlabOceanConfig: """ Configuration for a slab ocean model. - Attributes: + Parameters: mixed_layer_depth_name: Name of the mixed layer depth field. q_flux_name: Name of the heat flux field. """ @@ -34,7 +34,7 @@ class OceanConfig: """ Configuration for determining sea surface temperature from an ocean model. - Attributes: + Parameters: surface_temperature_name: Name of the sea surface temperature field. ocean_fraction_name: Name of the ocean fraction field. interpolate: If True, interpolate between ML-predicted surface temperature and @@ -102,16 +102,16 @@ def __init__(self, config: OceanConfig, timestep: datetime.timedelta): def __call__( self, - target_data: TensorMapping, input_data: TensorMapping, gen_data: TensorMapping, + target_data: TensorMapping, ) -> TensorDict: """ Args: - target_data: Denormalized data that includes mask and forcing data. Assumed - to correspond to the same time step as gen_data. input_data: Denormalized input data for current step. gen_data: Denormalized output data for current step. + target_data: Denormalized data that includes mask and forcing data. Assumed + to correspond to the same time step as gen_data. Returns: gen_data with sea surface temperature overwritten by ocean model. diff --git a/fme/fme/core/optimization.py b/fme/fme/core/optimization.py index 1638be5..2fdb1ca 100644 --- a/fme/fme/core/optimization.py +++ b/fme/fme/core/optimization.py @@ -1,18 +1,20 @@ import contextlib import dataclasses -from typing import Any, Literal, Mapping, Optional +import itertools +from typing import Any, Iterable, Literal, Mapping, Optional import torch import torch.cuda.amp as amp from torch import nn +from fme.core.generics.optimization import OptimizationABC from fme.core.scheduler import SchedulerConfig -class Optimization: +class Optimization(OptimizationABC): def __init__( self, - parameters, + parameters: Iterable[torch.nn.Parameter], optimizer_type: Literal["Adam", "FusedAdam"], lr: float, max_epochs: int, @@ -114,7 +116,7 @@ class OptimizationConfig: """ Configuration for optimization. - Attributes: + Parameters: optimizer_type: The type of optimizer to use. lr: The learning rate. kwargs: Additional keyword arguments to pass to the optimizer. 
@@ -132,7 +134,8 @@ class OptimizationConfig: default_factory=lambda: SchedulerConfig() ) - def build(self, parameters, max_epochs: int) -> Optimization: + def build(self, modules: torch.nn.ModuleList, max_epochs: int) -> Optimization: + parameters = itertools.chain(*[module.parameters() for module in modules]) return Optimization( parameters=parameters, optimizer_type=self.optimizer_type, @@ -151,7 +154,7 @@ def from_state(cls, state: Mapping[str, Any]) -> "OptimizationConfig": return cls(**state) -class NullOptimization: +class NullOptimization(OptimizationABC): @contextlib.contextmanager def autocast(self): yield diff --git a/fme/fme/core/packer.py b/fme/fme/core/packer.py index 362591e..9d10dca 100644 --- a/fme/fme/core/packer.py +++ b/fme/fme/core/packer.py @@ -22,7 +22,7 @@ def __init__(self, names: List[str]): def pack(self, tensors: TensorDict, axis=0) -> torch.Tensor: """ - Packs tensors into a single tensor, concatenated along a new axis + Packs tensors into a single tensor, concatenated along a new axis. Args: tensors: Dict from names to tensors. diff --git a/fme/fme/core/parameter_init.py b/fme/fme/core/parameter_init.py index d49649d..d3fdcd7 100644 --- a/fme/fme/core/parameter_init.py +++ b/fme/fme/core/parameter_init.py @@ -23,7 +23,7 @@ class FrozenParameterConfig: An exception is raised if a parameter is included by both lists. - Attributes: + Parameters: include: list of parameter names to freeze (set requires_grad = False) exclude: list of parameter names to ignore """ @@ -70,7 +70,7 @@ class ParameterInitializationConfig: pre-trained model. If the built model has larger weights than the pre-trained model, only the initial slice of the weights is overwritten. - Attributes: + Parameters: weight_path: path to a SingleModuleStepper checkpoint containing weights to load exclude_parameters: list of parameter names to exclude from the loaded @@ -123,6 +123,8 @@ def apply( def regularizer(): return torch.tensor(0.0, device=device) + return module, regularizer + else: loaded_state_dict = { name: value.to(device) for name, value in loaded_state_dict.items() @@ -136,6 +138,8 @@ def regularizer(): "which is not allowed" ) + non_optional_state_dict = loaded_state_dict + def regularizer(): loss = torch.tensor(0.0, device=device) for name in from_names: @@ -155,7 +159,7 @@ def regularizer(): self.alpha / 2 * torch.linalg.norm( - (param - loaded_state_dict[name]).flatten(), + (param - non_optional_state_dict[name]).flatten(), ord=2, ) ) diff --git a/fme/fme/core/prescriber.py b/fme/fme/core/prescriber.py index 3a813a2..d0b1f64 100644 --- a/fme/fme/core/prescriber.py +++ b/fme/fme/core/prescriber.py @@ -1,8 +1,7 @@ import dataclasses from typing import List -import torch - +from fme.core.masking import replace_on_mask from fme.core.typing_ import TensorDict, TensorMapping @@ -17,7 +16,7 @@ class PrescriberConfig: target value at 1 based on the mask variable, and it is assumed the mask variable lies in the range from 0 to 1. - Attributes: + Parameters: prescribed_name: Name of the variable to be overwritten. mask_name: Name of the mask variable. mask_value: Value of the mask variable in the region to be overwritten. @@ -78,7 +77,7 @@ def __call__( ) -> TensorDict: """ Args: - data: Dictionary of data containing the mask variable. + mask_data: Dictionary of data containing the mask variable. gen: Dictionary of data to use outside of mask region. target: Dictionary of data to use in mask region. 
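[reviewer note] The hunk below routes the Prescriber's overwrite through the replace_on_mask helper added in fme/fme/core/masking.py earlier in this diff, instead of inlining torch.round and torch.where. A small sketch of that helper's behavior on toy tensors (the values are illustrative only, and it assumes the fme package from this repo is importable):

import torch

from fme.core.masking import replace_on_mask

gen = torch.tensor([1.0, 2.0, 3.0, 4.0])         # generated output
target = torch.tensor([10.0, 20.0, 30.0, 40.0])  # data to prescribe
ocean_fraction = torch.tensor([0.9, 0.2, 0.6, 0.1])

# Where round(mask) == mask_value (here 1, i.e. mostly ocean) the target value
# is used; elsewhere the generated value is kept.
out = replace_on_mask(
    original=gen,
    replacement=target,
    mask=ocean_fraction,
    mask_value=1,
)
print(out)  # tensor([10.,  2., 30.,  4.])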
@@ -104,11 +103,11 @@ def __call__( ) else: # overwrite specified target variable in given mask region - rounded_mask = torch.round(mask_data[self.mask_name]).to(int) - output = torch.where( - condition=rounded_mask == self.mask_value, - input=target[self.prescribed_name], - other=gen[self.prescribed_name], + output = replace_on_mask( + original=gen[self.prescribed_name], + replacement=target[self.prescribed_name], + mask=mask_data[self.mask_name], + mask_value=self.mask_value, ) return {**gen, self.prescribed_name: output} diff --git a/fme/fme/core/registry/__init__.py b/fme/fme/core/registry/__init__.py new file mode 100644 index 0000000..867272f --- /dev/null +++ b/fme/fme/core/registry/__init__.py @@ -0,0 +1,2 @@ +from .corrector import CorrectorSelector +from .module import ModuleSelector diff --git a/fme/fme/core/registry/corrector.py b/fme/fme/core/registry/corrector.py new file mode 100644 index 0000000..a2ac927 --- /dev/null +++ b/fme/fme/core/registry/corrector.py @@ -0,0 +1,90 @@ +import dataclasses +import datetime +from typing import Any, Callable, ClassVar, Mapping, Type, TypeVar + +import dacite + +from fme.core.coordinates import HybridSigmaPressureCoordinate +from fme.core.corrector.registry import CorrectorConfigProtocol +from fme.core.gridded_ops import GriddedOperations + +from .registry import Registry + +CT = TypeVar("CT", bound=Type[CorrectorConfigProtocol]) + + +@dataclasses.dataclass +class CorrectorSelector: + """ + A dataclass containing all the information needed to build a + CorrectorConfigProtocol, including the type of the CorrectorConfigProtocol + and the data needed to build it. + + This is helpful as CorrectorSelector can be serialized and deserialized + without any additional information, whereas to load a + CorrectorConfigProtocol you would need to know the type of the + CorrectorConfigProtocol being loaded. + + It is also convenient because CorrectorSelector is a single class that can + be used to represent any CorrectorConfigProtocol, whereas + CorrectorConfigProtocol is a protocol that can be implemented by many + different classes. + + Parameters: + type: the type of the CorrectorConfigProtocol + config: data for a CorrectorConfigProtocol instance of the indicated type + + """ + + type: str + config: Mapping[str, Any] + registry: ClassVar[Registry] = Registry() + + def __post__init(self): + if self.registry is not Registry(): + raise ValueError("CorrectorSelector.registry should not be set manually") + + @classmethod + def register(cls, type_name) -> Callable[[CT], CT]: + return cls.registry.register(type_name) + + def build( + self, + gridded_operations: GriddedOperations, + vertical_coordinate: HybridSigmaPressureCoordinate, + timestep: datetime.timedelta, + ): + instance = self.registry.from_dict(self.get_state()) + return instance.build( + gridded_operations=gridded_operations, + vertical_coordinate=vertical_coordinate, + timestep=timestep, + ) + + def get_state(self) -> Mapping[str, Any]: + """ + Get a dictionary containing all the information needed to build a + CorrectorConfigProtocol. + + """ + return {"type": self.type, "config": self.config} + + @classmethod + def from_state(cls, state: Mapping[str, Any]) -> "CorrectorSelector": + """ + Create a CorrectorSelector from a dictionary containing all the information + needed to build a CorrectorConfigProtocol. 
+ """ + return dacite.from_dict( + data_class=cls, data=state, config=dacite.Config(strict=True) + ) + + @classmethod + def from_dict(cls, config: dict): + instance = cls.registry.from_dict(config) + return cls(config=instance, type=config["type"]) + + @classmethod + def get_available_types(cls): + """This class method is used to expose all available types of Correctors.""" + return cls(type="", config={}).registry._types.keys() diff --git a/fme/fme/core/registry.py b/fme/fme/core/registry/module.py similarity index 66% rename from fme/fme/core/registry.py rename to fme/fme/core/registry/module.py index ffe6ec7..29afade 100644 --- a/fme/fme/core/registry.py +++ b/fme/fme/core/registry/module.py @@ -1,10 +1,12 @@ import abc import dataclasses -from typing import Any, Callable, Dict, Mapping, Tuple, Type +from typing import Any, Callable, ClassVar, Mapping, Tuple, TypeVar, Union import dacite from torch import nn +from .registry import Registry + @dataclasses.dataclass class ModuleConfig(abc.ABC): @@ -30,8 +32,8 @@ def build( Args: n_in_channels: number of input channels n_out_channels: number of output channels - img_shape: last two dimensions of data, corresponding to lat and - lon when using FourCastNet conventions + img_shape: shape of last two dimensions of data, e.g. latitude and + longitude. Returns: a nn.Module @@ -49,41 +51,7 @@ def from_state(cls, state: Mapping[str, Any]) -> "ModuleConfig": ) -NET_REGISTRY: Dict[str, Type[ModuleConfig]] = {} - - -def get_available_module_types(): - return NET_REGISTRY.keys() - - -def register(name: str) -> Callable[[Type[ModuleConfig]], Type[ModuleConfig]]: - """ - Register a new ModuleConfig type with the NET_REGISTRY. - - This is useful for adding new ModuleConfig types to the registry from - other modules. - - Args: - name: name of the ModuleConfig type to register - - Returns: - a decorator which registers the decorated class with the NET_REGISTRY - """ - if not isinstance(name, str): - raise TypeError( - f"name must be a string, got {name}, " - "make sure to use as @register('module_name')" - ) - - def decorator(cls: Type[ModuleConfig]) -> Type[ModuleConfig]: - NET_REGISTRY[name] = cls - return cls - - return decorator - - -def get_from_registry(name) -> Type[ModuleConfig]: - return NET_REGISTRY[name] +MT = TypeVar("MT", bound=nn.Module) @dataclasses.dataclass @@ -100,22 +68,22 @@ class ModuleSelector: used to represent any ModuleConfig, whereas ModuleConfig is a protocol that can be implemented by many different classes. 
- Attributes: + Parameters: type: the type of the ModuleConfig config: data for a ModuleConfig instance of the indicated type """ type: str config: Mapping[str, Any] + registry: ClassVar[Registry] = Registry() + + def __post__init(self): + if self.registry is not Registry(): + raise ValueError("ModuleSelector.registry should not be set manually") - def __post_init__(self): - try: - self._config = get_from_registry(self.type).from_state(self.config) - except KeyError: - raise ValueError( - f"unknown module type {self.type}, " - f"known module types are {list(NET_REGISTRY.keys())}" - ) + @classmethod + def register(cls, type_name) -> Callable[[MT], MT]: + return cls.registry.register(type_name) def build( self, @@ -136,13 +104,14 @@ def build( Returns: a nn.Module """ - return self._config.build( + instance = self.registry.from_dict(self.get_state()) + return instance.build( n_in_channels=n_in_channels, n_out_channels=n_out_channels, img_shape=img_shape, ) - def get_state(self) -> Mapping[str, Any]: + def get_state(self) -> Union[Mapping[str, Any], dict]: """ Get a dictionary containing all the information needed to build a ModuleConfig. """ @@ -155,5 +124,15 @@ def from_state(cls, state: Mapping[str, Any]) -> "ModuleSelector": needed to build a ModuleConfig. """ return dacite.from_dict( - data_class=ModuleSelector, data=state, config=dacite.Config(strict=True) + data_class=cls, data=state, config=dacite.Config(strict=True) ) + + @classmethod + def from_dict(cls, config: dict): + instance = cls.registry.from_dict(config) + return cls(config=instance, type=config["type"]) + + @classmethod + def get_available_types(cls): + """This class method is used to expose all available types of Modules.""" + return cls(type="", config={}).registry._types.keys() diff --git a/fme/fme/core/registry/registry.py b/fme/fme/core/registry/registry.py new file mode 100644 index 0000000..0c8d5a7 --- /dev/null +++ b/fme/fme/core/registry/registry.py @@ -0,0 +1,73 @@ +from typing import Any, Callable, Dict, Generic, Mapping, Optional, Type, TypeVar + +import dacite + +T = TypeVar("T") +TT = TypeVar("TT", bound=Type) + + +class Registry(Generic[T]): + """ + Used to register and initialize multiple types of a dataclass. + """ + + def __init__(self, default_type: Optional[str] = None): + """ + Initialize the registry. + + Args: + default_type: if given, the "type" key in the config dict is optional + and by default this type will be used. + """ + self._types: Dict[str, Type[T]] = {} + self.default_type = default_type + + def register(self, type_name: str) -> Callable[[TT], TT]: + """ + Registers a configuration type with the registry. + + When registry.from_dict is called to initialize a dataclass, if the + "type" key in that dictionary is equal to the type_name you give here, + then the decorated class will be the one initialized from the data + in the "config" key. + + Args: + type_name: name used in configuration to indicate the decorated + class as the target type to be initialized when using from_dict. + """ + + def register_func(cls: TT) -> TT: + self._types[type_name] = cls + return cls + + return register_func + + def from_dict(self, config: Mapping[str, Any]) -> T: + """ + Creates a registered type from the given config dict. + + Config should have at least one key, "type", which indicates the type to + initialize based on its registered type name. This can be omitted if + this instance was initialized with a default type. + + It can also have a "config" key, which is a dict used to initialize the + dataclass. 
By default this is an empty dict. + """ + config = dict(config) + config.setdefault("config", {}) + if self.default_type is not None: + type_name = config.get("type", self.default_type) + else: + type_name = config["type"] + if type_name not in self._types: + raise ValueError( + f"Received unexpected type {type_name}, " + f"expected one of {self._types.keys()}" + ) + else: + instance = dacite.from_dict( + data_class=self._types[type_name], + data=config["config"], + config=dacite.Config(strict=True), + ) + return instance diff --git a/fme/fme/core/registry/test_module_registry.py b/fme/fme/core/registry/test_module_registry.py new file mode 100644 index 0000000..6dba17d --- /dev/null +++ b/fme/fme/core/registry/test_module_registry.py @@ -0,0 +1,38 @@ +import dataclasses +from typing import Iterable, List, Tuple + +import torch + +from .module import ModuleConfig, ModuleSelector + + +class MockModule(torch.nn.Module): + def __init__(self, param_shapes: Iterable[Tuple[int, ...]]): + super().__init__() + for i, shape in enumerate(param_shapes): + setattr(self, f"param{i}", torch.nn.Parameter(torch.randn(shape))) + + +@ModuleSelector.register("mock") +@dataclasses.dataclass +class MockModuleBuilder(ModuleConfig): + param_shapes: List[Tuple[int, ...]] + + def build(self, n_in_channels, n_out_channels, img_shape): + return MockModule(self.param_shapes) + + @classmethod + def from_state(cls, state): + return cls(state["param_shapes"]) + + def get_state(self): + return { + "param_shapes": self.param_shapes, + } + + +def test_register(): + """Make sure that the registry is working as expected.""" + selector = ModuleSelector(type="mock", config={"param_shapes": [(1, 2, 3)]}) + module = selector.build(1, 1, (16, 32)) + assert isinstance(module, MockModule) diff --git a/fme/fme/core/regrid.py b/fme/fme/core/regrid.py new file mode 100644 index 0000000..e69de29 diff --git a/fme/fme/core/scheduler.py b/fme/fme/core/scheduler.py index 8227af3..cc771ba 100644 --- a/fme/fme/core/scheduler.py +++ b/fme/fme/core/scheduler.py @@ -9,7 +9,7 @@ class SchedulerConfig: """ Configuration for a scheduler to use during training. - Attributes: + Parameters: type: Name of scheduler class from torch.optim.lr_scheduler, no scheduler is used by default. kwargs: Keyword arguments to pass to the scheduler constructor. diff --git a/fme/fme/core/stacker.py b/fme/fme/core/stacker.py new file mode 100644 index 0000000..707c662 --- /dev/null +++ b/fme/fme/core/stacker.py @@ -0,0 +1,151 @@ +import re +from typing import List, Mapping, Union + +import torch + +from fme.core.typing_ import TensorMapping + + +def natural_sort(alist: List[str]) -> List[str]: + """Sort to alphabetical order but with numbers sorted + numerically, e.g. a11 comes after a2. See [1] and [2]. + + [1] https://stackoverflow.com/questions/11150239/natural-sorting + [2] https://en.wikipedia.org/wiki/Natural_sort_order + """ + + def convert(text: str) -> Union[str, int]: + if text.isdigit(): + return int(text) + else: + return text.lower() + + def alphanum_key(item: str) -> List[Union[str, int]]: + return [convert(c) for c in re.split("([0-9]+)", item)] + + return sorted(alist, key=alphanum_key) + + +def unstack(tensor: torch.Tensor, names: List[str], dim: int = -1) -> TensorMapping: + """Unstack a 3D variable to a dictionary of 2D variables. + + Args: + tensor: 3D tensor to unstack, such as output by a Stacker. + names: List of names in natural order to assign to the unstacked variables. 
+ Stacker.get_all_level_names can help in retrieving these names when the + input tensor was stacked via a Stacker. + dim: Dimension along which to unstack. + + """ + if len(names) != tensor.size(dim): + raise ValueError( + f"Received {len(names)} names, but 3D tensor has " + f"{tensor.size(-1)} levels." + ) + if len(names) == 1: + return {names[0]: tensor.select(dim=dim, index=0)} + # split the output tensor along the vertical dimension + tensors = torch.split(tensor, 1, dim=dim) + return {name: tensor.squeeze(dim=dim) for name, tensor in zip(names, tensors)} + + +class Stacker: + """Handles extraction and stacking of data tensors for 3D variables.""" + + LEVEL_PATTERN = re.compile(r"_(\d+)$") + + def __init__( + self, + prefix_map: Mapping[str, List[str]], + ): + """ + Args: + prefix_map: Mapping which defines the correspondence between an arbitrary + set of "standard" names (e.g., "surface_pressure" or "air_temperature") + and lists of possible names or prefix variants (e.g., ["PRESsfc", "PS"] + or ["air_temperature_", "T_"]) found in the data. + """ + self._prefix_map = prefix_map + + @property + def prefix_map(self) -> Mapping[str, List[str]]: + """Mapping which defines the correspondence between an arbitrary set of + "standard" names (e.g., "surface_pressure" or "air_temperature") and + lists of possible names or prefix variants (e.g., ["PRESsfc", "PS"] or + ["air_temperature_", "T_"]) found in the data. + """ + return self._prefix_map + + @property + def standard_names(self) -> List[str]: + return list(self._prefix_map.keys()) + + def get_all_level_names(self, standard_name: str, data: TensorMapping) -> List[str]: + """Get the names of all variables in the data that match one of the + prefixes associated with the given standard name. If the standard name + corresponds to a 3D variable, returns all vertical level names in their + natural order. + """ + if standard_name not in self.standard_names: + raise ValueError(f"{standard_name} is not a standard name.") + for prefix_or_name in self._prefix_map[standard_name]: + if prefix_or_name in data: + return [prefix_or_name] + try: + return self._natural_sort_names(prefix_or_name, data) + except KeyError: + pass + raise KeyError( + f"No prefix associated with {standard_name} was found in data keys." + ) + + def __call__(self, standard_name: str, data: TensorMapping) -> torch.Tensor: + """Extract the variable corresponding to standard name and return as a + 3D tensor. + + """ + return self._stack_levels_try(standard_name, data) + + def _stack_levels(self, prefix_or_name: str, data: TensorMapping) -> torch.Tensor: + names = self._natural_sort_names(prefix_or_name, data) + # stack along the final dimension + return torch.stack([data[name] for name in names], dim=-1) + + def _stack_levels_try( + self, standard_name: str, data: TensorMapping + ) -> torch.Tensor: + prefixes_or_names = self._prefix_map[standard_name] + for prefix_or_name in prefixes_or_names: + if prefix_or_name in data: + # 2D variable, return as 1-level 3D tensor + return data[prefix_or_name].unsqueeze(-1) + try: + return self._stack_levels(prefix_or_name, data) + except KeyError: + pass + raise KeyError( + f"Found no matches for any of {prefixes_or_names} " + f"among the data names {list(data.keys())}." 
+ ) + + def _natural_sort_names(self, prefix: str, data: TensorMapping) -> List[str]: + names = [field_name for field_name in data if field_name.startswith(prefix)] + + levels = [] + for name in names: + match = self.LEVEL_PATTERN.search(name) + if match is None: + raise ValueError( + f"Invalid field name {name}, is a prefix variable " + "but does not end in _{number}." + ) + levels.append(int(match.group(1))) + + for i, level in enumerate(sorted(levels)): + if i != level: + raise ValueError(f"Missing level {i} in {prefix} levels {levels}.") + + if len(names) == 0: + raise KeyError(prefix) + + return natural_sort(names) diff --git a/fme/fme/core/stepper.py b/fme/fme/core/stepper.py deleted file mode 100644 index 9cea13b..0000000 --- a/fme/fme/core/stepper.py +++ /dev/null @@ -1,677 +0,0 @@ -import dataclasses -import datetime -from copy import copy -from typing import Any, Dict, List, Mapping, Optional, Tuple, Union - -import dacite -import torch -from torch import nn - -from fme.core.corrector import CorrectorConfig -from fme.core.data_loading.data_typing import SigmaCoordinates -from fme.core.data_loading.requirements import DataRequirements -from fme.core.data_loading.utils import decode_timestep, encode_timestep -from fme.core.device import get_device -from fme.core.distributed import Distributed -from fme.core.loss import WeightedMappingLossConfig -from fme.core.normalizer import ( - FromStateNormalizer, - NormalizationConfig, - StandardNormalizer, -) -from fme.core.ocean import Ocean, OceanConfig -from fme.core.packer import Packer -from fme.core.registry import ModuleSelector - -from .optimization import NullOptimization, Optimization -from .parameter_init import ParameterInitializationConfig -from .typing_ import TensorDict, TensorMapping - -DEFAULT_TIMESTEP = datetime.timedelta(hours=6) -DEFAULT_ENCODED_TIMESTEP = encode_timestep(DEFAULT_TIMESTEP) - - -@dataclasses.dataclass -class SingleModuleStepperConfig: - """ - Configuration for a single module stepper. - - Attributes: - builder: The module builder. - in_names: Names of input variables. - out_names: Names of output variables. - normalization: The normalization configuration. - parameter_init: The parameter initialization configuration. - ocean: The ocean configuration. - loss: The loss configuration. - corrector: The corrector configuration. - next_step_forcing_names: Names of forcing variables for the next timestep. - loss_normalization: The normalization configuration for the loss. - residual_normalization: Optional alternative to configure loss normalization. - If provided, it will be used for all *prognostic* variables in loss scaling. 
- """ - - builder: ModuleSelector - in_names: List[str] - out_names: List[str] - normalization: Union[NormalizationConfig, FromStateNormalizer] - parameter_init: ParameterInitializationConfig = dataclasses.field( - default_factory=lambda: ParameterInitializationConfig() - ) - ocean: Optional[OceanConfig] = None - loss: WeightedMappingLossConfig = dataclasses.field( - default_factory=lambda: WeightedMappingLossConfig() - ) - corrector: CorrectorConfig = dataclasses.field( - default_factory=lambda: CorrectorConfig() - ) - next_step_forcing_names: List[str] = dataclasses.field(default_factory=list) - loss_normalization: Optional[Union[NormalizationConfig, FromStateNormalizer]] = None - residual_normalization: Optional[ - Union[NormalizationConfig, FromStateNormalizer] - ] = None - - def __post_init__(self): - for name in self.next_step_forcing_names: - if name not in self.in_names: - raise ValueError( - f"next_step_forcing_name '{name}' not in in_names: {self.in_names}" - ) - if name in self.out_names: - raise ValueError( - f"next_step_forcing_name is an output variable: '{name}'" - ) - if ( - self.residual_normalization is not None - and self.loss_normalization is not None - ): - raise ValueError( - "Only one of residual_normalization, loss_normalization can " - "be provided." - "If residual_normalization is provided, it will be used for all " - "*prognostic* variables in loss scalng. " - "If loss_normalization is provided, it will be used for all variables " - "in loss scaling." - ) - - def get_data_requirements(self, n_forward_steps: int) -> DataRequirements: - return DataRequirements( - names=self.all_names, - n_timesteps=n_forward_steps + 1, - ) - - def get_forcing_data_requirements(self, n_forward_steps: int) -> DataRequirements: - if self.ocean is None: - names = self.forcing_names - else: - names = list(set(self.forcing_names).union(self.ocean.forcing_names)) - - return DataRequirements(names=names, n_timesteps=n_forward_steps + 1) - - def get_state(self): - return dataclasses.asdict(self) - - def get_base_weights(self) -> Optional[List[Mapping[str, Any]]]: - """ - If the model is being initialized from another model's weights for fine-tuning, - returns those weights. Otherwise, returns None. - - The list mirrors the order of `modules` in the `SingleModuleStepper` class. - """ - base_weights = self.parameter_init.get_base_weights() - if base_weights is not None: - return [base_weights] - else: - return None - - def get_stepper( - self, - img_shape: Tuple[int, int], - area: Optional[torch.Tensor], - sigma_coordinates: SigmaCoordinates, - timestep: datetime.timedelta, - ): - return SingleModuleStepper( - config=self, - img_shape=img_shape, - area=area, - sigma_coordinates=sigma_coordinates, - timestep=timestep, - ) - - @classmethod - def from_state(cls, state) -> "SingleModuleStepperConfig": - state = cls.remove_deprecated_keys(state) - return dacite.from_dict( - data_class=cls, data=state, config=dacite.Config(strict=True) - ) - - @property - def all_names(self): - """Names of all variables required, including auxiliary ones.""" - extra_names = [] - if self.ocean is not None: - extra_names.extend(self.ocean.forcing_names) - all_names = list(set(self.in_names).union(self.out_names).union(extra_names)) - return all_names - - @property - def normalize_names(self): - """Names of variables which require normalization. I.e. 
inputs/outputs.""" - return list(set(self.in_names).union(self.out_names)) - - @property - def forcing_names(self) -> List[str]: - """Names of variables which are inputs only.""" - return list(set(self.in_names) - set(self.out_names)) - - @property - def prognostic_names(self) -> List[str]: - """Names of variables which both inputs and outputs.""" - return list(set(self.out_names).intersection(self.in_names)) - - @classmethod - def remove_deprecated_keys(cls, state: Dict[str, Any]) -> Dict[str, Any]: - _unsupported_key_defaults = { - "conserve_dry_air": False, - "optimization": None, - "conservation_loss": {"dry_air_penalty": None}, - } - state_copy = state.copy() - for key, default in _unsupported_key_defaults.items(): - if key in state_copy: - if state_copy[key] == default or state_copy[key] is None: - del state_copy[key] - else: - raise ValueError( - f"The stepper config option {key} is deprecated and the setting" - f" provided, {state_copy[key]}, is no longer implemented. The " - "SingleModuleStepper being loaded from state cannot be run by " - "this version of the code." - ) - if "prescriber" in state_copy: - # want to maintain backwards compatibility for this particular feature - if state_copy["prescriber"] is not None: - if state_copy.get("ocean") is not None: - raise ValueError("Cannot specify both prescriber and ocean.") - state_copy["ocean"] = { - "surface_temperature_name": state_copy["prescriber"][ - "prescribed_name" - ], - "ocean_fraction_name": state_copy["prescriber"]["mask_name"], - "interpolate": state_copy["prescriber"]["interpolate"], - } - del state_copy["prescriber"] - return state_copy - - -@dataclasses.dataclass -class ExistingStepperConfig: - """ - Configuration for an existing stepper. This is only designed to point to - a serialized stepper checkpoint for loading, e.g., in the case of training - resumption. - - Attributes: - checkpoint_path: The path to the serialized checkpoint. - """ - - checkpoint_path: str - - def _load_checkpoint(self) -> Mapping[str, Any]: - return torch.load(self.checkpoint_path, map_location=get_device()) - - def get_data_requirements(self, n_forward_steps: int) -> DataRequirements: - return SingleModuleStepperConfig.from_state( - self._load_checkpoint()["stepper"]["config"] - ).get_data_requirements(n_forward_steps) - - def get_base_weights(self) -> Optional[List[Mapping[str, Any]]]: - return SingleModuleStepperConfig.from_state( - self._load_checkpoint()["stepper"]["config"] - ).get_base_weights() - - def get_stepper(self, img_shape, area, sigma_coordinates, timestep): - del img_shape # unused - return SingleModuleStepper.from_state( - self._load_checkpoint()["stepper"], - area=area, - sigma_coordinates=sigma_coordinates, - ) - - -def _combine_normalizers( - residual_normalizer: StandardNormalizer, - model_normalizer: StandardNormalizer, -) -> StandardNormalizer: - # Combine residual and model normalizers by overwriting the model normalizer - # values that are present in residual normalizer. The residual normalizer - # is assumed to have a subset of prognostic keys only. 
- means, stds = copy(model_normalizer.means), copy(model_normalizer.stds) - means.update(residual_normalizer.means) - stds.update(residual_normalizer.stds) - return StandardNormalizer(means=means, stds=stds) - - -def _cast_tensordict( - data: TensorDict, dtype: Optional[torch.dtype] = None -) -> TensorDict: - device = get_device() - return {name: value.to(device, dtype=dtype) for name, value in data.items()} - - -def _prepend_timestep( - data: TensorDict, timestep: TensorDict, time_dim: int = 1 -) -> TensorDict: - return { - k: torch.cat([timestep[k].unsqueeze(time_dim), v], dim=time_dim) - for k, v in data.items() - } - - -@dataclasses.dataclass -class SteppedData: - metrics: TensorDict - gen_data: TensorDict - target_data: TensorDict - gen_data_norm: TensorDict - target_data_norm: TensorDict - - def remove_initial_condition(self) -> "SteppedData": - return SteppedData( - metrics=self.metrics, - gen_data={k: v[:, 1:] for k, v in self.gen_data.items()}, - target_data={k: v[:, 1:] for k, v in self.target_data.items()}, - gen_data_norm={k: v[:, 1:] for k, v in self.gen_data_norm.items()}, - target_data_norm={k: v[:, 1:] for k, v in self.target_data_norm.items()}, - ) - - def copy(self) -> "SteppedData": - """Creates new dictionaries for the data but with the same tensors.""" - return SteppedData( - metrics=self.metrics, - gen_data={k: v for k, v in self.gen_data.items()}, - target_data={k: v for k, v in self.target_data.items()}, - gen_data_norm={k: v for k, v in self.gen_data_norm.items()}, - target_data_norm={k: v for k, v in self.target_data_norm.items()}, - ) - - def prepend_initial_condition( - self, - initial_condition: TensorDict, - normalized_initial_condition: TensorDict, - ) -> "SteppedData": - """ - Prepends an initial condition to the existing stepped data. - """ - initial_condition = _cast_tensordict(initial_condition) - normalized_initial_condition = _cast_tensordict(normalized_initial_condition) - - return SteppedData( - metrics=self.metrics, - gen_data=_prepend_timestep(self.gen_data, initial_condition), - target_data=_prepend_timestep(self.target_data, initial_condition), - gen_data_norm=_prepend_timestep( - self.gen_data_norm, normalized_initial_condition - ), - target_data_norm=_prepend_timestep( - self.target_data_norm, normalized_initial_condition - ), - ) - - -class SingleModuleStepper: - """ - Stepper class for a single pytorch module. - """ - - TIME_DIM = 1 - CHANNEL_DIM = -3 - - def __init__( - self, - config: SingleModuleStepperConfig, - img_shape: Tuple[int, int], - area: torch.Tensor, - sigma_coordinates: SigmaCoordinates, - timestep: datetime.timedelta, - init_weights: bool = True, - ): - """ - Args: - config: The configuration. - img_shape: Shape of domain as (n_lat, n_lon). - area: (n_lat, n_lon) array containing relative gridcell area, - in any units including unitless. - sigma_coordinates: The sigma coordinates. - timestep: Timestep of the model. - init_weights: Whether to initialize the weights. Should pass False if - the weights are about to be overwritten by a checkpoint. 
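# Rough usage sketch of the _prepend_timestep helper above, assuming tensors shaped
# [n_batch, n_time, n_lat, n_lon]; the variable name and values are synthetic.
import torch

n_batch, n_lat, n_lon = 2, 4, 8
series = {"PRESsfc": torch.rand(n_batch, 3, n_lat, n_lon)}   # three forward steps
initial = {"PRESsfc": torch.rand(n_batch, n_lat, n_lon)}     # single step, no time dim
prepended = {
    k: torch.cat([initial[k].unsqueeze(1), v], dim=1) for k, v in series.items()
}
assert prepended["PRESsfc"].shape == (n_batch, 4, n_lat, n_lon)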
- """ - n_in_channels = len(config.in_names) - n_out_channels = len(config.out_names) - self.in_packer = Packer(config.in_names) - self.out_packer = Packer(config.out_names) - self.normalizer = config.normalization.build(config.normalize_names) - if config.ocean is not None: - self.ocean: Optional[Ocean] = config.ocean.build( - config.in_names, config.out_names, timestep - ) - else: - self.ocean = None - self.module = config.builder.build( - n_in_channels=n_in_channels, - n_out_channels=n_out_channels, - img_shape=img_shape, - ) - module, self._l2_sp_tuning_regularizer = config.parameter_init.apply( - self.module, init_weights=init_weights - ) - self.module = module.to(get_device()) - - self._img_shape = img_shape - self._config = config - self._no_optimization = NullOptimization() - - dist = Distributed.get_instance() - self._is_distributed = dist.is_distributed() - self.module = dist.wrap_module(self.module) - - self.area = area.to(get_device()) - self.sigma_coordinates = sigma_coordinates.to(get_device()) - self.timestep = timestep - - self.loss_obj = config.loss.build(self.area, config.out_names, self.CHANNEL_DIM) - - self._corrector = config.corrector.build( - area=self.area, - sigma_coordinates=self.sigma_coordinates, - timestep=timestep, - ) - if config.loss_normalization is not None: - self.loss_normalizer = config.loss_normalization.build( - names=config.normalize_names - ) - elif config.residual_normalization is not None: - # Use residual norm for prognostic variables and input/output - # normalizer for diagnostic variables in loss - self.loss_normalizer = _combine_normalizers( - residual_normalizer=config.residual_normalization.build( - config.prognostic_names - ), - model_normalizer=self.normalizer, - ) - else: - self.loss_normalizer = self.normalizer - - def get_data_requirements(self, n_forward_steps: int) -> DataRequirements: - return self._config.get_data_requirements(n_forward_steps) - - @property - def effective_loss_scaling(self) -> TensorMapping: - """ - Effective loss scalings used to normalize outputs before computing loss. - y_loss_normalized_i = (y_i - y_mean_i) / loss_scaling_i - where loss_scaling_i = loss_normalizer_std_i / weight_i - """ - custom_weights = self._config.loss.weights - loss_normalizer_stds = self.loss_normalizer.stds - return { - k: loss_normalizer_stds[k] / custom_weights.get(k, 1.0) - for k in self._config.out_names - } - - @property - def prognostic_names(self) -> List[str]: - return sorted( - list(set(self.out_packer.names).intersection(self.in_packer.names)) - ) - - @property - def modules(self) -> nn.ModuleList: - """ - Returns: - A list of modules being trained. - """ - return nn.ModuleList([self.module]) - - def step( - self, - input: TensorMapping, - ocean_data: TensorMapping, - ) -> TensorDict: - """ - Step the model forward one timestep given input data. - - Args: - input: Mapping from variable name to tensor of shape - [n_batch, n_lat, n_lon]. This data is used as input for `self.module` - and is assumed to contain all input variables and be denormalized. - ocean_data: Mapping from variable name to tensor of shape - [n_batch, n_lat, n_lon]. This must contain the necessary data at the - output timestep for the ocean model (e.g. surface temperature, - mixed-layer depth etc.). - - Returns: - The denormalized output data at the next time step. 
- """ - input_norm = self.normalizer.normalize(input) - input_tensor = self.in_packer.pack(input_norm, axis=self.CHANNEL_DIM) - output_tensor = self.module(input_tensor) - output_norm = self.out_packer.unpack(output_tensor, axis=self.CHANNEL_DIM) - output = self.normalizer.denormalize(output_norm) - if self._corrector is not None: - output = self._corrector(input, output) - if self.ocean is not None: - output = self.ocean(ocean_data, input, output) - return output - - def predict( - self, - initial_condition: TensorMapping, - forcing_data: TensorMapping, - n_forward_steps: int, - ) -> TensorDict: - """ - Predict multiple steps forward given initial condition and forcing data. - - Args: - initial_condition: Mapping from variable name to tensors of shape - [n_batch, n_lat, n_lon]. This data is assumed to contain all prognostic - variables and be denormalized. - forcing_data: Mapping from variable name to tensors of shape - [n_batch, n_forward_steps + 1, n_lat, n_lon]. This contains the forcing - and ocean data for the initial condition and all subsequent timesteps. - n_forward_steps: The number of timesteps to run the model forward for. - - Returns: - The denormalized output data for all the forward timesteps. Shape of - each tensor will be [n_batch, n_forward_steps, n_lat, n_lon]. - """ - output_list = [] - state = initial_condition - forcing_names = self._config.forcing_names - ocean_forcing_names = self.ocean.forcing_names if self.ocean is not None else [] - for step in range(n_forward_steps): - current_step_forcing = { - k: ( - forcing_data[k][:, step] - if k not in self._config.next_step_forcing_names - else forcing_data[k][:, step + 1] - ) - for k in forcing_names - } - next_step_ocean_data = { - k: forcing_data[k][:, step + 1] for k in ocean_forcing_names - } - input_data = {**state, **current_step_forcing} - state = self.step(input_data, next_step_ocean_data) - output_list.append(state) - output_timeseries = {} - for name in state: - output_timeseries[name] = torch.stack( - [x[name] for x in output_list], dim=self.TIME_DIM - ) - return output_timeseries - - def get_initial_condition(self, data: TensorDict) -> Tuple[TensorDict, TensorDict]: - ic = {k: v.select(self.TIME_DIM, 0) for k, v in data.items()} - return ic, self.normalizer.normalize(ic) - - def run_on_batch( - self, - data: TensorDict, - optimization: Union[Optimization, NullOptimization], - n_forward_steps: int = 1, - ) -> SteppedData: - """ - Step the model forward multiple steps on a batch of data. - - Args: - data: The batch data where each tensor has shape - [n_sample, n_forward_steps + 1, n_lat, n_lon]. - optimization: The optimization class to use for updating the module. - Use `NullOptimization` to disable training. - n_forward_steps: The number of timesteps to run the model for. - aggregator: The data aggregator. - - Returns: - The loss metrics, the generated data, the normalized generated data, - and the normalized batch data. 
- """ - data = _cast_tensordict(data, dtype=torch.float) - time_dim = self.TIME_DIM - if self.ocean is None: - forcing_names = self._config.forcing_names - else: - forcing_names = self._config.forcing_names + self.ocean.forcing_names - forcing_data = {k: data[k] for k in forcing_names} - - loss = torch.tensor(0.0, device=get_device()) - metrics = {} - - input_data = { - k: data[k].select(time_dim, 0) for k in self._config.prognostic_names - } - # Remove the initial condition from target data - data = {k: data[k][:, 1:] for k in data} - - optimization.set_mode(self.module) - with optimization.autocast(): - # output from self.predict does not include initial condition - gen_data = self.predict(input_data, forcing_data, n_forward_steps) - - # compute loss for each timestep - for step in range(n_forward_steps): - gen_step = {k: v.select(time_dim, step) for k, v in gen_data.items()} - target_step = {k: v.select(time_dim, step) for k, v in data.items()} - gen_norm_step = self.loss_normalizer.normalize(gen_step) - target_norm_step = self.loss_normalizer.normalize(target_step) - - step_loss = self.loss_obj(gen_norm_step, target_norm_step) - loss += step_loss - metrics[f"loss_step_{step}"] = step_loss.detach() - - loss += self._l2_sp_tuning_regularizer() - - metrics["loss"] = loss.detach() - optimization.step_weights(loss) - - gen_data_norm = self.normalizer.normalize(gen_data) - full_data_norm = self.normalizer.normalize(data) - - return SteppedData( - metrics=metrics, - gen_data=gen_data, - target_data=data, - gen_data_norm=gen_data_norm, - target_data_norm=full_data_norm, - ) - - def get_state(self): - """ - Returns: - The state of the stepper. - """ - return { - "module": self.module.state_dict(), - "normalizer": self.normalizer.get_state(), - "img_shape": self._img_shape, - "config": self._config.get_state(), - "area": self.area, - "sigma_coordinates": self.sigma_coordinates.as_dict(), - "encoded_timestep": encode_timestep(self.timestep), - "loss_normalizer": self.loss_normalizer.get_state(), - } - - def load_state(self, state): - """ - Load the state of the stepper. - - Args: - state: The state to load. - """ - if "module" in state: - module = state["module"] - if "module.device_buffer" in module: - # for backwards compatibility with old checkpoints - del module["module.device_buffer"] - self.module.load_state_dict(module) - - @classmethod - def from_state( - cls, - state, - area: torch.Tensor, - sigma_coordinates: SigmaCoordinates, - ) -> "SingleModuleStepper": - """ - Load the state of the stepper. - - Args: - state: The state to load. - area: (n_lat, n_lon) array containing relative gridcell area, in any - units including unitless. - sigma_coordinates: The sigma coordinates. - - Returns: - The stepper. 
- """ - config = {**state["config"]} # make a copy to avoid mutating input - config["normalization"] = FromStateNormalizer(state["normalizer"]) - - # for backwards compatibility with previous steppers created w/o - # loss_normalization or residual_normalization - loss_normalizer_state = state.get("loss_normalizer", state["normalizer"]) - config["loss_normalization"] = FromStateNormalizer(loss_normalizer_state) - # Overwrite the residual_normalization key if it exists, since the combined - # loss scalings are saved in initial training as the loss_normalization - config["residual_normalization"] = None - - area = state.get("area", area) - if "sigma_coordinates" in state: - sigma_coordinates = dacite.from_dict( - data_class=SigmaCoordinates, - data=state["sigma_coordinates"], - config=dacite.Config(strict=True), - ) - encoded_timestep = state.get("encoded_timestep", DEFAULT_ENCODED_TIMESTEP) - timestep = decode_timestep(encoded_timestep) - if "img_shape" in state: - img_shape = state["img_shape"] - else: - # this is for backwards compatibility with old checkpoints - for v in state["data_shapes"].values(): - img_shape = v[-2:] - break - stepper = cls( - config=SingleModuleStepperConfig.from_state(config), - img_shape=img_shape, - area=area, - sigma_coordinates=sigma_coordinates, - timestep=timestep, - # don't need to initialize weights, we're about to load_state - init_weights=False, - ) - stepper.load_state(state) - return stepper diff --git a/fme/fme/core/test_climate_data.py b/fme/fme/core/test_climate_data.py index 5733de9..36db0ed 100644 --- a/fme/fme/core/test_climate_data.py +++ b/fme/fme/core/test_climate_data.py @@ -3,23 +3,7 @@ import pytest import torch -from fme.core.climate_data import ( - ClimateData, - _height_at_interface, - _layer_thickness, - _pressure_at_interface, - natural_sort, -) - - -def test__pressure_at_interface(): - ak = torch.tensor([2.0, 0.5, 0.0]) - bk = torch.tensor([0.0, 0.5, 1.0]) - psfc = torch.tensor([[1, 1], [2, 2]]) - pinterface = _pressure_at_interface(ak=ak, bk=bk, surface_pressure=psfc) - assert pinterface.shape == (2, 2, 3) - assert pinterface[0, 0, 0] == ak[0] - assert pinterface[0, 0, -1] == bk[-1] * psfc[0, 0] +from fme.core.climate_data import ClimateData, _height_at_interface, _layer_thickness def test__layer_thickness(): @@ -54,58 +38,6 @@ def test__height_at_interface(): ) -@pytest.mark.parametrize( - "names, sorted_names", - [ - ( - ["a_1", "b_1", "c_1", "a_2"], - [ - "a_1", - "a_2", - "b_1", - "c_1", - ], - ), - ( - [ - "a_0", - "a_1", - "a_12", - "a_2", - ], - [ - "a_0", - "a_1", - "a_2", - "a_12", - ], - ), - ( - [ - "a_0001", - "a_0012", - "a_0002", - ], - [ - "a_0001", - "a_0002", - "a_0012", - ], - ), - ( - [ - "ab1", - "aa10", - "aa2", - ], - ["aa2", "aa10", "ab1"], - ), - ], -) -def test_natural_sort(names, sorted_names): - assert natural_sort(names) == sorted_names - - @pytest.mark.parametrize("has_water_variable", [True, False]) def test_missing_specific_total_water(has_water_variable): """Check shape of specific total water and make sure that it returns None @@ -176,5 +108,5 @@ def _get_data(missing_water_layer: bool): 2, ) else: - with pytest.raises(KeyError): + with pytest.raises(ValueError): _ = climate_data.specific_total_water diff --git a/fme/fme/core/test_coordinates.py b/fme/fme/core/test_coordinates.py new file mode 100644 index 0000000..04c3a87 --- /dev/null +++ b/fme/fme/core/test_coordinates.py @@ -0,0 +1,115 @@ +import pytest +import torch + +from fme.core.coordinates import ( + HEALPixCoordinates, + 
HybridSigmaPressureCoordinate, + LatLonCoordinates, +) + + +@pytest.mark.parametrize( + "first, second", + [ + ( + HybridSigmaPressureCoordinate( + ak=torch.tensor([1, 2, 3]), bk=torch.tensor([4, 5, 6]) + ), + HybridSigmaPressureCoordinate( + ak=torch.tensor([1, 2, 3]), bk=torch.tensor([4, 5, 6]) + ), + ), + ( + LatLonCoordinates(lat=torch.tensor([1, 2, 3]), lon=torch.tensor([4, 5, 6])), + LatLonCoordinates(lat=torch.tensor([1, 2, 3]), lon=torch.tensor([4, 5, 6])), + ), + ( + HEALPixCoordinates( + face=torch.tensor([1, 2, 3]), + height=torch.tensor([4, 5, 6]), + width=torch.tensor([7, 8, 9]), + ), + HEALPixCoordinates( + face=torch.tensor([1, 2, 3]), + height=torch.tensor([4, 5, 6]), + width=torch.tensor([7, 8, 9]), + ), + ), + ], +) +def test_equality(first, second): + assert first == second + + +@pytest.mark.parametrize( + "first, second", + [ + ( + HybridSigmaPressureCoordinate( + ak=torch.tensor([1, 2, 3]), bk=torch.tensor([4, 5, 6]) + ), + HybridSigmaPressureCoordinate( + ak=torch.tensor([1, 2, 3]), bk=torch.tensor([5, 6, 7]) + ), + ), + ( + LatLonCoordinates(lat=torch.tensor([1, 2, 3]), lon=torch.tensor([4, 5, 6])), + LatLonCoordinates(lat=torch.tensor([1, 2, 3]), lon=torch.tensor([5, 6, 7])), + ), + ( + HEALPixCoordinates( + face=torch.tensor([1, 2, 3]), + height=torch.tensor([4, 5, 6]), + width=torch.tensor([7, 8, 9]), + ), + HEALPixCoordinates( + face=torch.tensor([1, 2, 3]), + height=torch.tensor([4, 5, 6]), + width=torch.tensor([8, 9, 10]), + ), + ), + ( + LatLonCoordinates(lat=torch.tensor([1, 2, 3]), lon=torch.tensor([4, 5, 6])), + HEALPixCoordinates( + face=torch.tensor([1, 2, 3]), + height=torch.tensor([4, 5, 6]), + width=torch.tensor([7, 8, 9]), + ), + ), + ], +) +def test_inequality(first, second): + assert first != second + + +def test_vertical_integral_shape(): + nlat, nlon, nz = 4, 8, 3 + water = torch.rand(nlat, nlon, nz) + pressure = torch.rand(nlat, nlon) + ak, bk = torch.arange(nz + 1), torch.arange(nz + 1) + coords = HybridSigmaPressureCoordinate(ak, bk) + water_path = coords.vertical_integral(water, pressure) + assert water_path.shape == (nlat, nlon) + + +def test_vertical_coordinates_raises_value_error(): + ak, bk = torch.arange(3), torch.arange(4) + with pytest.raises(ValueError): + HybridSigmaPressureCoordinate(ak, bk) + + +def test_vertical_coordinates_len(): + ak, bk = torch.arange(3), torch.arange(3) + coords = HybridSigmaPressureCoordinate(ak, bk) + assert len(coords) == 3 + + +def test_interface_pressure(): + ak = torch.tensor([2.0, 0.5, 0.0]) + bk = torch.tensor([0.0, 0.5, 1.0]) + psfc = torch.tensor([[1, 1], [2, 2]]) + coords = HybridSigmaPressureCoordinate(ak, bk) + pinterface = coords.interface_pressure(psfc) + assert pinterface.shape == (2, 2, 3) + assert pinterface[0, 0, 0] == ak[0] + assert pinterface[0, 0, -1] == bk[-1] * psfc[0, 0] diff --git a/fme/fme/core/test_corrector.py b/fme/fme/core/test_corrector.py deleted file mode 100644 index 72a7b46..0000000 --- a/fme/fme/core/test_corrector.py +++ /dev/null @@ -1,199 +0,0 @@ -import datetime - -import numpy as np -import pytest -import torch - -from fme.ace.inference.derived_variables import total_water_path_budget_residual -from fme.core import ClimateData, metrics -from fme.core.corrector import ( - _force_conserve_dry_air, - _force_conserve_moisture, - _force_positive, - _force_zero_global_mean_moisture_advection, -) -from fme.core.data_loading.data_typing import SigmaCoordinates -from fme.core.loss import get_dry_air_nonconservation - -TIMESTEP = datetime.timedelta(hours=6) - - -def 
test_force_no_global_mean_moisture_advection(): - torch.random.manual_seed(0) - data = { - "tendency_of_total_water_path_due_to_advection": torch.rand(size=(3, 2, 5, 5)), - } - area_weights = 1.0 + torch.rand(size=(5, 5)) - original_mean = metrics.weighted_mean( - data["tendency_of_total_water_path_due_to_advection"], - weights=area_weights, - dim=[-2, -1], - ) - assert (original_mean.abs() > 0.0).all() - fixed_data = _force_zero_global_mean_moisture_advection( - data, - area=area_weights, - ) - new_mean = metrics.weighted_mean( - fixed_data["tendency_of_total_water_path_due_to_advection"], - weights=area_weights, - dim=[-2, -1], - ) - assert (new_mean.abs() < original_mean.abs()).all() - np.testing.assert_almost_equal(new_mean.cpu().numpy(), 0.0, decimal=6) - - -def test_force_conserve_dry_air(): - torch.random.manual_seed(0) - data = { - "PRESsfc": 10.0 + torch.rand(size=(3, 2, 5, 5)), - "specific_total_water_0": torch.rand(size=(3, 2, 5, 5)), - "specific_total_water_1": torch.rand(size=(3, 2, 5, 5)), - } - sigma_coordinates = SigmaCoordinates( - ak=torch.asarray([3.0, 1.0, 0.0]), bk=torch.asarray([0.0, 0.6, 1.0]) - ) - area_weights = 1.0 + torch.rand(size=(5, 5)) - original_nonconservation = get_dry_air_nonconservation( - data, - sigma_coordinates=sigma_coordinates, - area_weights=area_weights, - ) - assert original_nonconservation > 0.0 - in_data = {k: v.select(dim=1, index=0) for k, v in data.items()} - out_data = {k: v.select(dim=1, index=1) for k, v in data.items()} - fixed_out_data = _force_conserve_dry_air( - in_data, - out_data, - sigma_coordinates=sigma_coordinates, - area=area_weights, - ) - new_data = { - k: torch.stack([v, fixed_out_data[k]], dim=1) for k, v in in_data.items() - } - new_nonconservation = get_dry_air_nonconservation( - new_data, - sigma_coordinates=sigma_coordinates, - area_weights=area_weights, - ) - assert new_nonconservation < original_nonconservation - np.testing.assert_almost_equal(new_nonconservation.cpu().numpy(), 0.0, decimal=6) - - -@pytest.mark.parametrize("fv3_data", [True, False]) -@pytest.mark.parametrize( - "global_only, terms_to_modify", - [ - (True, "precipitation"), - (True, "evaporation"), - (False, "advection_and_precipitation"), - (False, "advection_and_evaporation"), - ], -) -def test_force_conserve_moisture(fv3_data: bool, global_only: bool, terms_to_modify): - torch.random.manual_seed(0) - if fv3_data: - data = { - "PRESsfc": 10.0 + torch.rand(size=(3, 2, 5, 5)), - "specific_total_water_0": torch.rand(size=(3, 2, 5, 5)), - "specific_total_water_1": torch.rand(size=(3, 2, 5, 5)), - "PRATEsfc": torch.rand(size=(3, 2, 5, 5)), - "LHTFLsfc": torch.rand(size=(3, 2, 5, 5)), - "tendency_of_total_water_path_due_to_advection": torch.rand( - size=(3, 2, 5, 5) - ), - } - else: - data = { - "PS": 10.0 + torch.rand(size=(3, 2, 5, 5)), - "specific_total_water_0": torch.rand(size=(3, 2, 5, 5)), - "specific_total_water_1": torch.rand(size=(3, 2, 5, 5)), - "surface_precipitation_rate": torch.rand(size=(3, 2, 5, 5)), - "LHFLX": torch.rand(size=(3, 2, 5, 5)), - "tendency_of_total_water_path_due_to_advection": torch.rand( - size=(3, 2, 5, 5) - ), - } - sigma_coordinates = SigmaCoordinates( - ak=torch.asarray([3.0, 1.0, 0.0]), bk=torch.asarray([0.0, 0.6, 1.0]) - ) - area_weights = 1.0 + torch.rand(size=(5, 5)) - data["tendency_of_total_water_path_due_to_advection"] -= metrics.weighted_mean( - data["tendency_of_total_water_path_due_to_advection"], - weights=area_weights, - dim=[-2, -1], - )[..., None, None] - original_budget_residual = 
total_water_path_budget_residual( - ClimateData(data), - sigma_coordinates=sigma_coordinates, - timestep=TIMESTEP, - )[ - :, 1 - ] # no meaning for initial value data, want first timestep - if global_only: - original_budget_residual = metrics.weighted_mean( - original_budget_residual, weights=area_weights, dim=[-2, -1] - ) - original_budget_residual = original_budget_residual.cpu().numpy() - original_dry_air = ( - ClimateData(data) - .surface_pressure_due_to_dry_air(sigma_coordinates) - .cpu() - .numpy() - ) - assert np.any(np.abs(original_budget_residual) > 0.0) - in_data = {k: v.select(dim=1, index=0) for k, v in data.items()} - out_data = {k: v.select(dim=1, index=1) for k, v in data.items()} - fixed_out_data = _force_conserve_moisture( - in_data, - out_data, - sigma_coordinates=sigma_coordinates, - area=area_weights, - timestep=TIMESTEP, - terms_to_modify=terms_to_modify, - ) - new_data = { - k: torch.stack([v, fixed_out_data[k]], dim=1) for k, v in in_data.items() - } - new_budget_residual = total_water_path_budget_residual( - ClimateData(new_data), - sigma_coordinates=sigma_coordinates, - timestep=TIMESTEP, - )[ - :, 1 - ] # no meaning for initial value data, want first timestep - new_dry_air = ( - ClimateData(data) - .surface_pressure_due_to_dry_air(sigma_coordinates) - .cpu() - .numpy() - ) - - global_budget_residual = ( - metrics.weighted_mean(new_budget_residual, weights=area_weights, dim=[-2, -1]) - .cpu() - .numpy() - ) - np.testing.assert_almost_equal(global_budget_residual, 0.0, decimal=6) - - if not global_only: - new_budget_residual = new_budget_residual.cpu().numpy() - assert np.all(np.abs(new_budget_residual) < np.abs(original_budget_residual)) - np.testing.assert_almost_equal(new_budget_residual, 0.0, decimal=6) - - np.testing.assert_almost_equal(new_dry_air, original_dry_air, decimal=6) - - -def test__force_positive(): - data = { - "foo": torch.tensor([[-1.0, 0.0], [0.0, 1.0], [1.0, 2.0]]), - "bar": torch.tensor([[-1.0, 0.0], [0.0, -3.0], [1.0, 2.0]]), - } - original_min = torch.min(data["foo"]) - assert original_min < 0.0 - fixed_data = _force_positive(data, ["foo"]) - new_min = torch.min(fixed_data["foo"]) - # Ensure the minimum value of 'foo' is now 0 - torch.testing.assert_close(new_min, torch.tensor(0.0)) - # Ensure other variables are not modified - torch.testing.assert_close(fixed_data["bar"], data["bar"]) diff --git a/fme/fme/core/test_distributed.py b/fme/fme/core/test_distributed.py index b904dcc..4105c77 100644 --- a/fme/fme/core/test_distributed.py +++ b/fme/fme/core/test_distributed.py @@ -1,6 +1,8 @@ import pytest import torch +from fme import get_device + from .distributed import pad_tensor_at_end, unpad_tensor_at_end @@ -32,7 +34,7 @@ def test_pad_tensor_at_end(padding, fill_value): ], ) def test_pad_unpad_rountrip(padding): - tensor = torch.ones(2, 3, 4) + tensor = torch.ones(2, 3, 4, device=get_device()) padded_tensor = pad_tensor_at_end(tensor, padding) unpadded_tensor = unpad_tensor_at_end(padded_tensor, padding) assert unpadded_tensor.size() == tensor.size() diff --git a/fme/fme/core/test_gridded_ops.py b/fme/fme/core/test_gridded_ops.py new file mode 100644 index 0000000..910fcf6 --- /dev/null +++ b/fme/fme/core/test_gridded_ops.py @@ -0,0 +1,39 @@ +from typing import Any, Dict, Type + +import pytest +import torch + +from fme.core.gridded_ops import GriddedOperations, HEALPixOperations, LatLonOperations + + +@pytest.mark.parametrize( + "state, expected_class", + [ + ( + { + "type": "LatLonOperations", + "state": {"area_weights": torch.tensor([1.0, 
2.0])}, + }, + LatLonOperations, + ), + ( + { + "type": "HEALPixOperations", + "state": {}, + }, + HEALPixOperations, + ), + ], +) +def test_gridded_operations_from_state( + state: Dict[str, Any], + expected_class: Type[GriddedOperations], +): + ops = GriddedOperations.from_state(state) + assert isinstance(ops, expected_class) + + recovered_state = ops.to_state() + assert recovered_state == state + + with pytest.raises(RuntimeError): + expected_class.from_state(state["state"]) diff --git a/fme/fme/core/test_histogram.py b/fme/fme/core/test_histogram.py index e81c500..dfb1339 100644 --- a/fme/fme/core/test_histogram.py +++ b/fme/fme/core/test_histogram.py @@ -79,20 +79,28 @@ def test_dynamic_histogram_random_values(n_times: int, time_bin_len: int): def test_dynamic_histogram_extends_as_expected(): histogram = DynamicHistogram(n_times=1, n_bins=200) histogram.add(torch.as_tensor([[-1.0, 0.0, 1.0]])) - np.testing.assert_approx_equal(histogram.bin_edges[0], -1.0, significant=6) - np.testing.assert_approx_equal(histogram.bin_edges[-1], 1.0, significant=6) + bin_edges = histogram.bin_edges + assert bin_edges is not None + np.testing.assert_approx_equal(bin_edges[0], -1.0, significant=6) + np.testing.assert_approx_equal(bin_edges[-1], 1.0, significant=6) histogram.add(torch.as_tensor([[-2.0]])) + bin_edges = histogram.bin_edges + assert bin_edges is not None # double in size to the left, length becomes 4, from -3 to 1.0 - np.testing.assert_approx_equal(histogram.bin_edges[0], -3.0, significant=6) - np.testing.assert_approx_equal(histogram.bin_edges[-1], 1.0, significant=6) + np.testing.assert_approx_equal(bin_edges[0], -3.0, significant=6) + np.testing.assert_approx_equal(bin_edges[-1], 1.0, significant=6) histogram.add(torch.as_tensor([[2.0]])) + bin_edges = histogram.bin_edges + assert bin_edges is not None # double in size to the right, length becomes 8, from -3 to 5.0 - np.testing.assert_approx_equal(histogram.bin_edges[0], -3.0, significant=6) - np.testing.assert_approx_equal(histogram.bin_edges[-1], 5.0, significant=6) + np.testing.assert_approx_equal(bin_edges[0], -3.0, significant=6) + np.testing.assert_approx_equal(bin_edges[-1], 5.0, significant=6) histogram.add(torch.as_tensor([[27.0]])) + bin_edges = histogram.bin_edges + assert bin_edges is not None # double in size twice to the right, length becomes 32, from -3 to 29.0 - np.testing.assert_approx_equal(histogram.bin_edges[0], -3.0, significant=6) - np.testing.assert_approx_equal(histogram.bin_edges[-1], 29.0, significant=6) + np.testing.assert_approx_equal(bin_edges[0], -3.0, significant=6) + np.testing.assert_approx_equal(bin_edges[-1], 29.0, significant=6) def test_histogram_handles_uniform_field(): @@ -134,8 +142,8 @@ def test_compared_dynamic_histograms(shape, percentiles): for data_type in ["target", "prediction"]: assert isinstance(wandb_result[f"{var_name}"], matplotlib.figure.Figure) for p in percentiles: - assert ( - wandb_result[f"{data_type}/{p}th-percentile/{var_name}"].shape == () + assert isinstance( + wandb_result[f"{data_type}/{p}th-percentile/{var_name}"], float ) ds = histogram.get_dataset() diff --git a/fme/fme/core/test_loss.py b/fme/fme/core/test_loss.py index dbb927c..52952d0 100644 --- a/fme/fme/core/test_loss.py +++ b/fme/fme/core/test_loss.py @@ -3,6 +3,7 @@ from fme.core import metrics from fme.core.device import get_device +from fme.core.gridded_ops import LatLonOperations from fme.core.loss import ( AreaWeightedMSELoss, GlobalMeanLoss, @@ -19,7 +20,7 @@ def test_loss_builds_and_runs(global_mean_type): 
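# Minimal standalone sketch (not the fme classes) of the serialization convention the
# GriddedOperations test above relies on: to_state() records a "type" name plus a
# "state" payload, and from_state() dispatches on the recorded type to rebuild the
# object. The tiny class and registry below are hypothetical.
import torch

class TinyLatLonOps:
    def __init__(self, area_weights: torch.Tensor):
        self.area_weights = area_weights

    def to_state(self):
        return {"type": type(self).__name__, "state": {"area_weights": self.area_weights}}

_OPS_REGISTRY = {"TinyLatLonOps": TinyLatLonOps}

def ops_from_state(state):
    return _OPS_REGISTRY[state["type"]](**state["state"])

restored = ops_from_state(TinyLatLonOps(torch.tensor([1.0, 2.0])).to_state())
assert isinstance(restored, TinyLatLonOps)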
config = LossConfig(global_mean_type=global_mean_type) area = torch.randn(10, 10, device=get_device()) - loss = config.build(area, reduction="mean") + loss = config.build(LatLonOperations(area).area_weighted_mean, reduction="mean") x = torch.randn(10, 10, 10, 10, 10, device=get_device()) y = torch.randn(10, 10, 10, 10, 10, device=get_device()) result = loss(x, y) @@ -28,36 +29,47 @@ def test_loss_builds_and_runs(global_mean_type): def test_loss_of_zeros_is_variance(): + torch.manual_seed(0) config = LossConfig(global_mean_type=None) area = torch.randn(10, 10, device=get_device()) - loss = config.build(area, reduction="mean") + loss = config.build(LatLonOperations(area).area_weighted_mean, reduction="mean") x = torch.zeros(10, 10, 10, 10, 10, device=get_device()) y = torch.randn(10, 10, 10, 10, 10, device=get_device()) result = loss(x, y) assert result.shape == () assert isinstance(result, torch.Tensor) - torch.testing.assert_close(result, y.var()) + tol = {"rtol": 1e-4, "atol": 1e-4} if str(get_device()).startswith("cuda") else {} + torch.testing.assert_close(result, y.var(), **tol) @pytest.mark.parametrize("global_mean_weight", [0.0, 1.0, 5.0]) def test_loss_of_zeros_is_one_plus_global_mean_weight(global_mean_weight: float): + torch.manual_seed(0) config = LossConfig( global_mean_type="LpLoss", global_mean_weight=global_mean_weight ) area = torch.randn(10, 10, device=get_device()) - loss = config.build(area, reduction="mean") + loss = config.build(LatLonOperations(area).area_weighted_mean, reduction="mean") x = torch.zeros(10, 10, 10, 10, 10, device=get_device()) y = torch.randn(10, 10, 10, 10, 10, device=get_device()) result = loss(x, y) assert result.shape == () assert isinstance(result, torch.Tensor) expected = torch.tensor(1.0 + global_mean_weight) - torch.testing.assert_close(result.cpu(), expected, atol=0.01, rtol=0) + tol = ( + {"atol": 0.015, "rtol": 0.01} + if str(get_device()).startswith("cuda") + else {"atol": 0.01, "rtol": 0.0} + ) + torch.testing.assert_close(result.cpu(), expected, **tol) def test_global_mean_loss(): + torch.manual_seed(0) area = torch.randn(10, 10, device=get_device()) - loss = GlobalMeanLoss(area=area, loss=torch.nn.MSELoss()) + loss = GlobalMeanLoss( + LatLonOperations(area).area_weighted_mean, loss=torch.nn.MSELoss() + ) x = torch.zeros(10, 10, 10, 10, 10, device=get_device()) y = torch.randn(10, 10, 10, 10, 10, device=get_device()) result = loss(x, y) @@ -77,7 +89,7 @@ def test_area_weighted_mse(): x = torch.rand(10, 10).to(get_device()) target = torch.rand(10, 10).to(get_device()) area = torch.rand(10, 10).to(get_device()) - area_weighted_mse = AreaWeightedMSELoss(area) + area_weighted_mse = AreaWeightedMSELoss(LatLonOperations(area).area_weighted_mean) result = area_weighted_mse(x, target) expected = metrics.weighted_mean( torch.nn.MSELoss(reduction="none")(x, target), weights=area, dim=(-2, -1) @@ -160,9 +172,10 @@ def test_WeightedMappingLossConfig_no_weights(): out_names = [f"var_{i}" for i in range(n_channels)] channel_dim = -3 area = torch.tensor([]) # area not used by this config + area_weighted_mean = LatLonOperations(area).area_weighted_mean mapping_loss_config = WeightedMappingLossConfig() - loss = loss_config.build(area, reduction="mean") - mapping_loss = mapping_loss_config.build(area, out_names, channel_dim) + loss = loss_config.build(area_weighted_mean, reduction="mean") + mapping_loss = mapping_loss_config.build(area_weighted_mean, out_names, channel_dim) packer = Packer(out_names) x_mapping = {name: torch.randn(4, 5, 5).to(get_device()) 
for name in out_names} @@ -180,7 +193,9 @@ def test_WeightedMappingLossConfig_weights(): mapping_loss_config = WeightedMappingLossConfig( type="MSE", weights={"var_0": 4.0, "var_1": 1.0} ) - mapping_loss = mapping_loss_config.build(area, out_names, channel_dim) + mapping_loss = mapping_loss_config.build( + LatLonOperations(area).area_weighted_mean, out_names, channel_dim + ) x0 = torch.ones(4, 5, 5).to(get_device()) x1 = 2.0 * x0 diff --git a/fme/fme/core/test_masking.py b/fme/fme/core/test_masking.py new file mode 100644 index 0000000..5a32ce8 --- /dev/null +++ b/fme/fme/core/test_masking.py @@ -0,0 +1,120 @@ +import pytest +import torch + +import fme +from fme.core.masking import MaskingConfig +from fme.core.stacker import Stacker + + +def test_masking_config(): + config = MaskingConfig( + mask_name="a", + surface_mask_name="b", + mask_value=1, + fill_value=0.0, + ) + mask = config.build() + assert mask.mask_name == "a" + assert mask.mask_value == 1 + assert mask.fill_value == 0.0 + + with pytest.raises(ValueError) as err: + _ = MaskingConfig( + mask_name="a", + surface_mask_name="b", + mask_value=3, + fill_value=0.0, + ) + assert "mask_value must be either 0 or 1" in str(err.value) + + +_SIZE = (4, 4) + +_MASK_DATA = { + "surface_mask": torch.ones(_SIZE, device=fme.get_device()), + "mask_0": torch.ones(_SIZE, device=fme.get_device()), + "mask_1": torch.ones(_SIZE, device=fme.get_device()), +} +_MASK_DATA["surface_mask"][1, 1] = 0 +_MASK_DATA["mask_0"][0, :] = 0 +_MASK_DATA["mask_1"][1, :] = 0 + + +_DATA = { + "PRESsfc": 10.0 + torch.rand(size=_SIZE, device=fme.get_device()), + "specific_total_water_0": torch.rand(size=_SIZE, device=fme.get_device()), + "specific_total_water_1": torch.rand(size=_SIZE, device=fme.get_device()), +} + + +def test_masking(): + config = MaskingConfig( + mask_name="mask", + surface_mask_name="surface_mask", + mask_value=0, + fill_value=0.0, + ) + mask = config.build() + stacker = Stacker( + { + "surface_pressure": ["PRESsfc", "PS"], + "specific_total_water": ["specific_total_water_"], + } + ) + output = mask(stacker, _DATA, _MASK_DATA) + assert output["PRESsfc"][1, 1] == 0.0 + assert output["PRESsfc"][0, 1] != 0.0 + assert torch.all(output["specific_total_water_0"][0, :] == 0.0) + assert torch.all(output["specific_total_water_1"][1, :] == 0.0) + assert torch.all(output["specific_total_water_1"][0, :] != 0.0) + + +def test_masking_no_3d_masking(): + config = MaskingConfig( + mask_name="surface_mask", + mask_value=0, + fill_value=0.0, + ) + mask = config.build() + stacker = Stacker({"surface_pressure": ["PRESsfc", "PS"]}) + output = mask(stacker, _DATA, _MASK_DATA) + assert output["PRESsfc"][1, 1] == 0.0 + assert output["PRESsfc"][0, 1] != 0.0 + assert torch.all(output["specific_total_water_0"][0, :] != 0.0) + assert torch.all(output["specific_total_water_1"][1, :] != 0.0) + assert torch.all(output["specific_total_water_1"][0, :] != 0.0) + + +def test_masking_no_surface_masking(): + config = MaskingConfig( + mask_name="mask", + mask_value=0, + fill_value=0.0, + ) + mask = config.build() + stacker = Stacker({"specific_total_water": ["specific_total_water_"]}) + output = mask(stacker, _DATA, _MASK_DATA) + assert output["PRESsfc"][1, 1] != 0.0 + assert output["PRESsfc"][0, 1] != 0.0 + assert torch.all(output["specific_total_water_0"][0, :] == 0.0) + assert torch.all(output["specific_total_water_1"][1, :] == 0.0) + assert torch.all(output["specific_total_water_1"][0, :] != 0.0) + + +def test_masking_missing_2d_mask(): + config = MaskingConfig( + mask_name="mask", + 
mask_value=0, + fill_value=0.0, + ) + mask = config.build() + stacker = Stacker( + { + "surface_pressure": ["PRESsfc", "PS"], + "specific_total_water": ["specific_total_water_"], + } + ) + with pytest.raises(RuntimeError) as err: + _ = mask(stacker, _DATA, _MASK_DATA) + assert "surface_mask_name is None" in str(err.value) + assert "surface_pressure" in str(err.value) diff --git a/fme/fme/core/test_metrics.py b/fme/fme/core/test_metrics.py index fdf29dc..71b9077 100644 --- a/fme/fme/core/test_metrics.py +++ b/fme/fme/core/test_metrics.py @@ -4,12 +4,12 @@ import torch_harmonics import fme +from fme.core.coordinates import HybridSigmaPressureCoordinate from fme.core.metrics import ( net_surface_energy_flux, quantile, spherical_power_spectrum, surface_pressure_due_to_dry_air, - vertical_integral, ) @@ -255,22 +255,13 @@ def test_gradient_magnitude_percent_diff(): torch.testing.assert_close(percent_diff, -100 * torch.ones((5, 2))) -def test_vertical_integral_shape(): - nlat, nlon, nz = 4, 8, 3 - water = torch.rand(nlat, nlon, nz) - pressure = torch.rand(nlat, nlon) - ak, bk = torch.arange(nz + 1), torch.arange(nz + 1) - water_path = vertical_integral(water, pressure, ak, bk) - assert water_path.shape == (nlat, nlon) - - def test_dry_air_shapes(): nlat, nlon, nz = 4, 8, 3 water = torch.rand(nlat, nlon, nz) pressure = torch.rand(nlat, nlon) ak, bk = torch.arange(nz + 1), torch.arange(nz + 1) - - dry_air = surface_pressure_due_to_dry_air(water, pressure, ak, bk) + coords = HybridSigmaPressureCoordinate(ak, bk) + dry_air = surface_pressure_due_to_dry_air(water, pressure, coords) assert dry_air.shape == (nlat, nlon) @@ -286,8 +277,8 @@ def test_single_level_dry_air_no_water(): water = torch.zeros(nlat, nlon, nz) pressure = torch.rand(nlat, nlon) ak, bk = single_level_ak_bk() - - dry_air = surface_pressure_due_to_dry_air(water, pressure, ak, bk) + coords = HybridSigmaPressureCoordinate(ak, bk) + dry_air = surface_pressure_due_to_dry_air(water, pressure, coords) np.testing.assert_allclose(dry_air.cpu().numpy(), pressure.cpu().numpy()) @@ -297,8 +288,8 @@ def test_single_level_dry_air_all_water(): water = torch.ones(nlat, nlon, nz) pressure = torch.rand(nlat, nlon) ak, bk = single_level_ak_bk() - - dry_air = surface_pressure_due_to_dry_air(water, pressure, ak, bk) + coords = HybridSigmaPressureCoordinate(ak, bk) + dry_air = surface_pressure_due_to_dry_air(water, pressure, coords) np.testing.assert_almost_equal(dry_air.cpu().numpy(), 0.0, decimal=6) @@ -309,8 +300,8 @@ def test_single_level_dry_air_some_water(): pressure = torch.rand(nlat, nlon) ak, bk = single_level_ak_bk() target_dry_air = pressure * (1.0 - water[:, :, 0]) - - dry_air = surface_pressure_due_to_dry_air(water, pressure, ak, bk) + coords = HybridSigmaPressureCoordinate(ak, bk) + dry_air = surface_pressure_due_to_dry_air(water, pressure, coords) np.testing.assert_allclose( dry_air.cpu().numpy(), target_dry_air.cpu().numpy(), rtol=1e-5 ) diff --git a/fme/fme/core/test_normalizer.py b/fme/fme/core/test_normalizer.py index 5aaeca4..06d2edb 100644 --- a/fme/fme/core/test_normalizer.py +++ b/fme/fme/core/test_normalizer.py @@ -1,8 +1,8 @@ import pytest import torch +from fme.core.device import move_tensordict_to_device from fme.core.normalizer import NormalizationConfig, StandardNormalizer -from fme.core.typing_ import TensorDict def test_normalize_depends_on_mean(): @@ -48,35 +48,35 @@ def test_denormalize_depends_on_std(): def test_normalize_and_denormalize_random_tensor(): torch.manual_seed(0) # randomly set means and stds - means = {"a": 
torch.randn(1), "b": torch.randn(1)} - stds = {"a": torch.randn(1), "b": torch.randn(1)} + means = move_tensordict_to_device({"a": torch.randn(1), "b": torch.randn(1)}) + stds = move_tensordict_to_device({"a": torch.randn(1), "b": torch.randn(1)}) normalizer = StandardNormalizer(means=means, stds=stds) - tensors = {"a": torch.randn(10), "b": torch.randn(10)} + tensors = move_tensordict_to_device({"a": torch.randn(10), "b": torch.randn(10)}) denormalized = normalizer.denormalize(normalizer.normalize(tensors)) assert torch.allclose(denormalized["a"], tensors["a"]) assert torch.allclose(denormalized["b"], tensors["b"]) -def test_normalization_config_exclude_names(): - torch.manual_seed(0) +def test_missing_normalization_build_raises_error(): normalization = NormalizationConfig( means={"a": 1.0, "b": 2.0}, stds={"a": 1.0, "b": 1.0}, - exclude_names=["c"], ) - normalizer = normalization.build(["a", "b", "c"]) - tensors = {"a": torch.randn(10), "b": torch.randn(10), "c": torch.randn(10)} - normalized: TensorDict = normalizer.normalize(tensors) - denormalized = normalizer.denormalize(normalized) - assert torch.all(normalized["c"] == tensors["c"]) - assert torch.all(denormalized["c"] == tensors["c"]) + all_names = ["a", "b", "c"] + with pytest.raises(KeyError): + normalization.build(all_names) -def test_missing_normalization_raises_error(): +def test_tensors_with_missing_normalization_stats_get_filtered(): normalization = NormalizationConfig( means={"a": 1.0, "b": 2.0}, stds={"a": 1.0, "b": 1.0}, - ) - all_names = ["a", "b", "c"] - with pytest.raises(KeyError): - normalization.build(all_names) + ).build(["a", "b"]) + sample_input = {"a": torch.zeros(1), "b": torch.zeros(1), "c": torch.zeros(1)} + sample_input = move_tensordict_to_device(sample_input) + + normalized = normalization.normalize(sample_input) + assert "c" not in normalized + + denormalized = normalization.denormalize(sample_input) + assert "c" not in denormalized diff --git a/fme/fme/core/test_ocean.py b/fme/fme/core/test_ocean.py index 0fbd9da..26de5ec 100644 --- a/fme/fme/core/test_ocean.py +++ b/fme/fme/core/test_ocean.py @@ -18,7 +18,7 @@ def test_ocean_prescribed(): target_data = {"sst": torch.tensor([22.0, 25.0]), "of": torch.tensor([0.2, 0.8])} input_data = {"sst": torch.tensor([20.0, 21.0]), "foo": torch.tensor([1, 2])} gen_data = {"sst": torch.tensor([23.0, 26.0]), "foo": torch.tensor([2, 3])} - output_data = ocean(target_data, input_data, gen_data) + output_data = ocean(input_data, gen_data, target_data) expected_output = {"sst": torch.tensor([23.0, 25.0]), "foo": torch.tensor([2, 3])} assert set(output_data) == set(expected_output) for name in output_data: @@ -52,7 +52,7 @@ def test_ocean_slab(): } input_data = {"sst": torch.tensor([20.0])} gen_data = {**fluxes, "sst": torch.tensor([25.0])} - output_data = ocean(target_data, input_data, gen_data) + output_data = ocean(input_data, gen_data, target_data) expected_sst_tendency = mixed_layer_temperature_tendency( expected_net_surface_energy_flux, target_data["qf"], target_data["mld"] ) diff --git a/fme/fme/core/test_parameter_init.py b/fme/fme/core/test_parameter_init.py index 084b4b4..46857dc 100644 --- a/fme/fme/core/test_parameter_init.py +++ b/fme/fme/core/test_parameter_init.py @@ -9,11 +9,11 @@ import pytest import torch +from fme.ace.stepper import SingleModuleStepper, SingleModuleStepperConfig from fme.core import parameter_init -from fme.core.data_loading.data_typing import SigmaCoordinates +from fme.core.coordinates import HybridSigmaPressureCoordinate from 
fme.core.device import get_device -from fme.core.normalizer import FromStateNormalizer -from fme.core.stepper import SingleModuleStepper, SingleModuleStepperConfig +from fme.core.gridded_ops import LatLonOperations from fme.core.typing_ import TensorMapping from fme.core.wildcard import wildcard_match @@ -32,22 +32,20 @@ def test_builder_with_weights_loads_same_state(tmpdir): "builder": sfno_config_data, "in_names": ["x"], "out_names": ["x"], - "normalization": FromStateNormalizer( - state={ - "means": {"x": np.random.randn(1)}, - "stds": {"x": np.random.randn(1)}, - } - ), + "normalization": { + "means": {"x": np.random.randn(1).item()}, + "stds": {"x": np.random.randn(1).item()}, + }, } area = torch.ones((1, 16, 32)).to(get_device()) - sigma_coordinates = SigmaCoordinates(ak=torch.arange(7), bk=torch.arange(7)).to( - get_device() - ) + vertical_coordinate = HybridSigmaPressureCoordinate( + ak=torch.arange(7), bk=torch.arange(7) + ).to(get_device()) stepper_config = SingleModuleStepperConfig.from_state(stepper_config_data) stepper = stepper_config.get_stepper( img_shape=(16, 32), - area=area, - sigma_coordinates=sigma_coordinates, + gridded_operations=LatLonOperations(area), + vertical_coordinate=vertical_coordinate, timestep=TIMESTEP, ) torch.save( @@ -64,19 +62,17 @@ def test_builder_with_weights_loads_same_state(tmpdir): "parameter_init": parameter_init_config, "in_names": ["x"], "out_names": ["x"], - "normalization": FromStateNormalizer( - state={ - "means": {"x": np.random.randn(1)}, - "stds": {"x": np.random.randn(1)}, - } - ), + "normalization": { + "means": {"x": np.random.randn(1).item()}, + "stds": {"x": np.random.randn(1).item()}, + }, } with_builder_stepper = SingleModuleStepperConfig.from_state( with_builder_stepper_config_data ).get_stepper( img_shape=(16, 32), - area=area, - sigma_coordinates=sigma_coordinates, + gridded_operations=LatLonOperations(area), + vertical_coordinate=vertical_coordinate, timestep=TIMESTEP, ) assert_same_state( @@ -148,7 +144,7 @@ def test_builder_with_weights_sfno_init( """ Integration test for the BuilderWithWeights stepper with a SFNO. 
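# Illustrative config fragment (a plain dict) showing the shape the updated
# test_parameter_init cases above now use: "normalization" is a mapping of scalar
# means/stds keyed by variable name, replacing the old FromStateNormalizer objects,
# and steppers are built with gridded_operations/vertical_coordinate arguments.
# The builder entry is a placeholder, not a verified module configuration.
stepper_config_data = {
    "builder": {"type": "SphericalFourierNeuralOperatorNet", "config": {}},  # placeholder
    "in_names": ["x"],
    "out_names": ["x"],
    "normalization": {"means": {"x": 0.0}, "stds": {"x": 1.0}},
}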
""" - with_builder_stepper_config_data, area, sigma_coordinates, stepper = get_config( + with_builder_stepper_config_data, area, vertical_coordinate, stepper = get_config( loaded_shape, extra_built_layer, tmpdir ) if expect_exception: @@ -157,8 +153,8 @@ def test_builder_with_weights_sfno_init( with_builder_stepper_config_data ).get_stepper( img_shape=built_shape, - area=area, - sigma_coordinates=sigma_coordinates, + gridded_operations=LatLonOperations(area), + vertical_coordinate=vertical_coordinate, timestep=TIMESTEP, ) else: @@ -166,8 +162,8 @@ def test_builder_with_weights_sfno_init( with_builder_stepper_config_data ).get_stepper( img_shape=built_shape, - area=area, - sigma_coordinates=sigma_coordinates, + gridded_operations=LatLonOperations(area), + vertical_coordinate=vertical_coordinate, timestep=TIMESTEP, ) if extra_built_layer: @@ -203,22 +199,20 @@ def get_config( "builder": sfno_config_data, "in_names": ["x"], "out_names": ["x"], - "normalization": FromStateNormalizer( - state={ - "means": {"x": np.random.randn(1)}, - "stds": {"x": np.random.randn(1)}, - } - ), + "normalization": { + "means": {"x": np.random.randn(1).item()}, + "stds": {"x": np.random.randn(1).item()}, + }, } area = torch.ones((1, 16, 32)).to(get_device()) - sigma_coordinates = SigmaCoordinates(ak=torch.arange(7), bk=torch.arange(7)).to( - get_device() - ) + vertical_coordinate = HybridSigmaPressureCoordinate( + ak=torch.arange(7), bk=torch.arange(7) + ).to(get_device()) stepper_config = SingleModuleStepperConfig.from_state(stepper_config_data) stepper = stepper_config.get_stepper( img_shape=loaded_shape, - area=area, - sigma_coordinates=sigma_coordinates, + gridded_operations=LatLonOperations(area), + vertical_coordinate=vertical_coordinate, timestep=TIMESTEP, ) built_sfno_config_data = copy.deepcopy(sfno_config_data) @@ -238,35 +232,31 @@ def get_config( "parameter_init": parameter_init_config, "in_names": ["x"], "out_names": ["x"], - "normalization": FromStateNormalizer( - state={ - "means": {"x": np.random.randn(1)}, - "stds": {"x": np.random.randn(1)}, - } - ), + "normalization": { + "means": {"x": np.random.randn(1).item()}, + "stds": {"x": np.random.randn(1).item()}, + }, } - return with_builder_stepper_config_data, area, sigma_coordinates, stepper + return with_builder_stepper_config_data, area, vertical_coordinate, stepper def test_with_weights_saved_stepper_does_not_need_untuned_weights(tmpdir): img_shape = (16, 32) - with_builder_stepper_config_data, area, sigma_coordinates, stepper = get_config( + with_builder_stepper_config_data, area, vertical_coordinate, stepper = get_config( loaded_shape=img_shape, extra_built_layer=False, tmpdir=tmpdir ) with_builder_stepper = SingleModuleStepperConfig.from_state( with_builder_stepper_config_data ).get_stepper( img_shape=img_shape, - area=area, - sigma_coordinates=sigma_coordinates, + gridded_operations=LatLonOperations(area), + vertical_coordinate=vertical_coordinate, timestep=TIMESTEP, ) stepper_state = with_builder_stepper.get_state() # should be able to initialize stepper from its state without the untuned weights (tmpdir / "weights.ckpt").remove() - stepper = SingleModuleStepper.from_state( - stepper_state, area=area, sigma_coordinates=sigma_coordinates - ) + stepper = SingleModuleStepper.from_state(stepper_state) assert isinstance(stepper, SingleModuleStepper) diff --git a/fme/fme/core/test_registry.py b/fme/fme/core/test_registry.py deleted file mode 100644 index 006febf..0000000 --- a/fme/fme/core/test_registry.py +++ /dev/null @@ -1,52 +0,0 @@ -import 
dataclasses -from typing import Iterable, Tuple - -import pytest -import torch - -from fme.core import registry - - -class MockModule(torch.nn.Module): - def __init__(self, param_shapes: Iterable[Tuple[int, ...]]): - super().__init__() - for i, shape in enumerate(param_shapes): - setattr(self, f"param{i}", torch.nn.Parameter(torch.randn(shape))) - - -@dataclasses.dataclass -class MockModuleBuilder: - param_shapes: Iterable[Tuple[int, ...]] - - def build(self, n_in_channels, n_out_channels, img_shape): - return MockModule(self.param_shapes) - - @classmethod - def from_state(self, state): - return MockModuleBuilder(state["param_shapes"]) - - def get_state(self): - return { - "param_shapes": self.param_shapes, - } - - -def test_register(): - """Make sure that the registry is working as expected.""" - original_registry = registry.NET_REGISTRY - try: - registry.NET_REGISTRY = {} - with pytest.raises(ValueError): - selector = registry.ModuleSelector( - type="mock", - config={"param_shapes": [(1, 2, 3)]}, - ) - registry.register("mock")(MockModuleBuilder) - selector = registry.ModuleSelector( - type="mock", - config={"param_shapes": [(1, 2, 3)]}, - ) - module = selector.build(1, 1, (16, 32)) - assert isinstance(module, MockModule) - finally: - registry.NET_REGISTRY = original_registry diff --git a/fme/fme/core/test_stacker.py b/fme/fme/core/test_stacker.py new file mode 100644 index 0000000..25cf8be --- /dev/null +++ b/fme/fme/core/test_stacker.py @@ -0,0 +1,139 @@ +import pytest +import torch + +from .stacker import Stacker, natural_sort, unstack + + +@pytest.mark.parametrize( + "names, sorted_names", + [ + ( + ["a_1", "b_1", "c_1", "a_2"], + [ + "a_1", + "a_2", + "b_1", + "c_1", + ], + ), + ( + [ + "a_0", + "a_1", + "a_12", + "a_2", + ], + [ + "a_0", + "a_1", + "a_2", + "a_12", + ], + ), + ( + [ + "a_0001", + "a_0012", + "a_0002", + ], + [ + "a_0001", + "a_0002", + "a_0012", + ], + ), + ( + [ + "ab1", + "aa10", + "aa2", + ], + ["aa2", "aa10", "ab1"], + ), + ], +) +def test_natural_sort(names, sorted_names): + assert natural_sort(names) == sorted_names + + +_PREFIX_MAPS = [{"a": ["aa_", "ab_"], "b": ["b_", "bb_"], "c": ["ca", "cb"]}] + +_DATA_NAMES = [ + ["aa_0", "aa_1", "b_00", "b_01", "b", "ca"], + ["ab_0", "ab_1", "bb_1", "bb_2", "cb"], + ["ab_0", "b_", "cc"], +] + + +@pytest.mark.parametrize( + "prefix_map, data_names, expected_level_names", + [ + ( + _PREFIX_MAPS[0], + _DATA_NAMES[0], + {"a": ["aa_0", "aa_1"], "b": ["b_00", "b_01"], "c": ["ca"]}, + ), + ( + _PREFIX_MAPS[0], + _DATA_NAMES[1], + {"a": ["ab_0", "ab_1"], "b": ValueError, "c": ["cb"]}, + ), + ( + _PREFIX_MAPS[0], + _DATA_NAMES[2], + {"a": ["ab_0"], "b": ["b_"], "c": KeyError}, + ), + ], +) +def test_get_all_level_names(prefix_map, data_names, expected_level_names): + data = {name: torch.zeros(2, 2) for name in data_names} + stacker = Stacker(prefix_map) + for standard_name in expected_level_names.keys(): + if isinstance(expected_level_names[standard_name], type): + with pytest.raises(expected_level_names[standard_name]): + stacker.get_all_level_names(standard_name, data) + else: + assert ( + stacker.get_all_level_names(standard_name, data) + == expected_level_names[standard_name] + ) + + +@pytest.mark.parametrize( + "prefix_map, data_names, expected_level_names", + [ + ( + _PREFIX_MAPS[0], + _DATA_NAMES[0], + {"a": ["aa_0", "aa_1"], "b": ["b_00", "b_01"], "c": ["ca"]}, + ), + ( + _PREFIX_MAPS[0], + _DATA_NAMES[1], + {"a": ["ab_0", "ab_1"], "c": ["cb"]}, + ), + ( + _PREFIX_MAPS[0], + _DATA_NAMES[2], + {"a": ["ab_0"], "b": 
["b_"]}, + ), + ], +) +def test_stack_unstack(prefix_map, data_names, expected_level_names): + torch.manual_seed(0) + data = {name: torch.rand(2, 2) for name in data_names} + stacker = Stacker(prefix_map) + for standard_name in expected_level_names.keys(): + level_names = expected_level_names[standard_name] + if len(level_names) == 1: + expected_stacked = data[level_names[0]].unsqueeze(-1) + else: + expected_stacked = torch.stack([data[name] for name in level_names], dim=-1) + stacked = stacker(standard_name, data) + assert torch.allclose(stacked, expected_stacked) + unstacked = unstack( + stacked, names=stacker.get_all_level_names(standard_name, data) + ) + for name in level_names: + assert name in unstacked + assert torch.allclose(unstacked[name], data[name]) diff --git a/fme/fme/core/test_timing.py b/fme/fme/core/test_timing.py new file mode 100644 index 0000000..b5e78cf --- /dev/null +++ b/fme/fme/core/test_timing.py @@ -0,0 +1,199 @@ +import logging +import time + +import numpy as np +import pytest + +from fme.core.timing import CumulativeTimer, GlobalTimer + + +def test_CumulativeTimer(): + category = "foo" + cumulative_timer = CumulativeTimer(category) + + cumulative_timer.start() + time.sleep(0.01) + cumulative_timer.stop() + + time.sleep(0.01) + + cumulative_timer.start() + time.sleep(0.01) + cumulative_timer.stop() + + assert cumulative_timer.duration == pytest.approx(0.02, abs=0.005) + + +def test_CumulativeTimer_start_error(): + category = "foo" + cumulative_timer = CumulativeTimer(category) + cumulative_timer.start() + with pytest.raises(RuntimeError, match=f"timer {category!r} is already running"): + cumulative_timer.start() + + +def test_CumulativeTimer_stop_error(): + category = "foo" + cumulative_timer = CumulativeTimer(category) + with pytest.raises( + RuntimeError, match=f"must call start for timer {category!r} before stop" + ): + cumulative_timer.stop() + + +def test_CumulativeTimer_duration_error(): + category = "foo" + cumulative_timer = CumulativeTimer(category) + cumulative_timer.start() + with pytest.raises(RuntimeError, match=f"timer {category!r} is still running"): + cumulative_timer.duration + + +def exercise_active_timer(): + timer = GlobalTimer.get_instance() + + timer.start("foo") + time.sleep(0.01) + timer.stop() + + timer = GlobalTimer.get_instance() + + timer.start("bar") + time.sleep(0.02) + timer.stop() + + assert timer.get_duration("foo") == pytest.approx(0.01, abs=0.005) + assert timer.get_duration("bar") == pytest.approx(0.02, abs=0.005) + + with pytest.raises(KeyError): + timer.get_duration("baz") + + durations = timer.get_durations() + assert durations["foo"] == pytest.approx(0.01, abs=0.005) + assert durations["bar"] == pytest.approx(0.02, abs=0.005) + + +def test_GlobalTimer(): + with GlobalTimer(): + exercise_active_timer() + + # Check that timer is reset within new context. + with GlobalTimer(): + timer = GlobalTimer.get_instance() + assert timer.get_durations() == {} + + +def test_GlobalTimer_resets_after_exception(): + with pytest.raises(ValueError): + with GlobalTimer(): + timer = GlobalTimer.get_instance() + timer.start("foo") + raise ValueError() + + # Check that the context manager clears the state of the timer after an + # exception. If it were not clear, starting the timer for "foo" would raise + # an error. 
+ with GlobalTimer(): + timer = GlobalTimer.get_instance() + timer.start("foo") + timer.stop() + + +def test_GlobalTimer_multiple_context_error(): + with pytest.raises(RuntimeError, match="GlobalTimer is currently in use"): + with GlobalTimer(), GlobalTimer(): + pass + + +def test_inactive_GlobalTimer_warning(): + with pytest.warns(UserWarning, match=r"inactive"): + GlobalTimer.get_instance() + + +@pytest.mark.filterwarnings("ignore:The GlobalTimer") +def test_inactive_GlobalTimer_start(): + timer = GlobalTimer.get_instance() + timer.start("foo") + + +@pytest.mark.filterwarnings("ignore:The GlobalTimer") +def test_inactive_GlobalTimer_stop(): + timer = GlobalTimer.get_instance() + timer.stop() + + +@pytest.mark.filterwarnings("ignore:The GlobalTimer") +def test_inactive_GlobalTimer_get_duration(): + timer = GlobalTimer.get_instance() + result = timer.get_duration("foo") + assert np.isnan(result) + + +@pytest.mark.filterwarnings("ignore:The GlobalTimer") +def test_inactive_GlobalTimer_get_durations(): + timer = GlobalTimer.get_instance() + result = timer.get_durations() + assert result == {} + + +@pytest.mark.filterwarnings("ignore:The GlobalTimer") +def test_inactive_GlobalTimer_log_durations(caplog): + timer = GlobalTimer.get_instance() + with caplog.at_level(logging.INFO): + timer.log_durations() + assert len(caplog.records) == 0 + + +@pytest.mark.filterwarnings("ignore:The GlobalTimer") +def test_GlobalTimer_inactive_then_active(): + # Make sure we can instantiate an inactive timer, but then still create and + # use an active timer later. + GlobalTimer.get_instance() + + with GlobalTimer(): + exercise_active_timer() + + +def test_GlobalTimer_context(): + with GlobalTimer(): + timer = GlobalTimer.get_instance() + with timer.context("foo"): + time.sleep(0.01) + assert timer.get_duration("foo") > 0.01 + + +def test_GlobalTimer_context_with_exception(): + with pytest.raises(ValueError): + with GlobalTimer(): + timer = GlobalTimer.get_instance() + with timer.context("foo"): + time.sleep(0.01) + raise ValueError() + assert timer.get_duration("foo") > 0.01 + + +def test_GlobalTimer_single_inner_timer(): + with pytest.raises( + RuntimeError, match="GlobalTimer already has an active inner timer" + ): + with GlobalTimer(): + timer = GlobalTimer.get_instance() + with timer.context("foo"): + with timer.context("bar"): + pass + + +def test_GlobalTimer_nested_outer_context(): + with GlobalTimer(): + timer = GlobalTimer.get_instance() + with timer.outer_context("foo"): + with timer.context("bar"): + pass + + +def test_GlobalTimer_double_nested_outer_context(): + with GlobalTimer(): + timer = GlobalTimer.get_instance() + with timer.outer_context("foo"): + with timer.outer_context("bar"): + pass diff --git a/fme/fme/core/testing/__init__.py b/fme/fme/core/testing/__init__.py index 8de3120..6852c92 100644 --- a/fme/fme/core/testing/__init__.py +++ b/fme/fme/core/testing/__init__.py @@ -1,10 +1,2 @@ from .distributed import mock_distributed -from .fv3gfs_data import ( - DimSizes, - FV3GFSData, - MonthlyReferenceData, - StatsData, - save_2d_netcdf, - save_scalar_netcdf, -) from .wandb import mock_wandb diff --git a/fme/fme/core/testing/wandb.py b/fme/fme/core/testing/wandb.py index 1523a9c..e806854 100644 --- a/fme/fme/core/testing/wandb.py +++ b/fme/fme/core/testing/wandb.py @@ -1,6 +1,6 @@ import collections import contextlib -from typing import Any, Dict, Mapping +from typing import Any, Dict, List, Mapping from fme.core import wandb from fme.core.distributed import Distributed @@ -11,6 +11,7 @@ def 
__init__(self): self._enabled = False self._configured = False self._logs: Dict[int, Dict[str, Any]] = collections.defaultdict(dict) + self._last_step = 0 def configure(self, log_to_wandb: bool): dist = Distributed.get_instance() @@ -31,12 +32,24 @@ def watch(self, modules): pass def log(self, data: Mapping[str, Any], step: int, sleep=None): + if step < self._last_step: + raise ValueError( + f"step {step} is less than last step {self._last_step}, " + "steps must be logged in order" + ) + self._last_step = step # sleep arg is ignored since we don't want to sleep in tests if self._enabled: self._logs[step].update(data) - def get_logs(self) -> Dict[int, Dict[str, Any]]: - return self._logs + def get_logs(self) -> List[Dict[str, Any]]: + if len(self._logs) == 0: + return [] + n_logs = max(self._logs.keys()) + return_value: List[Dict[str, Any]] = [dict() for _ in range(n_logs + 1)] + for step, log in self._logs.items(): + return_value[step] = log + return return_value def clean_wandb_dir(self, experiment_dir: str): pass diff --git a/fme/fme/core/timing.py b/fme/fme/core/timing.py new file mode 100644 index 0000000..1f0f670 --- /dev/null +++ b/fme/fme/core/timing.py @@ -0,0 +1,168 @@ +import contextlib +import logging +import time +import warnings +from typing import Dict, Optional + +import numpy as np + +singleton: Optional["GlobalTimer"] = None + + +INACTIVE_WARNING_MESSAGE = ( + "The GlobalTimer is currently inactive; therefore no timing information " + "will be recorded. To activate it, wrap your code within the GlobalTimer() " + "context." +) + + +class CumulativeTimer: + def __init__(self, category): + self._duration = 0.0 + self._start_time = None + self._category = category + + def start(self): + if self._start_time is not None: + raise RuntimeError(f"timer {self._category!r} is already running") + self._start_time = time.time() + + def stop(self): + if self._start_time is None: + raise RuntimeError( + f"must call start for timer {self._category!r} before stop" + ) + self._duration += time.time() - self._start_time + self._start_time = None + + @property + def duration(self) -> float: + if self._start_time is not None: + raise RuntimeError(f"timer {self._category!r} is still running") + return self._duration + + +class GlobalTimer: + """ + A singleton class to make timing inference code easier. + """ + + @classmethod + def get_instance(cls) -> "GlobalTimer": + """ + Get the singleton instance of the GlobalTimer class. + """ + global singleton + if singleton is None: + singleton = cls() + warnings.warn(INACTIVE_WARNING_MESSAGE) + return singleton + + @classmethod + def __enter__(cls): + global singleton + if singleton is None: + singleton = cls() + if singleton._active: + raise RuntimeError("GlobalTimer is currently in use in another context") + singleton._active = True + + @classmethod + def __exit__(cls, type, value, traceback): + global singleton + singleton = None + + def __init__(self): + self._timers: Dict[str, CumulativeTimer] = {} + self._active = False + self._current_category: Optional[str] = None + + def outer_context(self, category: str) -> contextlib.AbstractContextManager: + """ + Context manager for timing a block of code. + + May be active at the same time as other timers. + """ + + @contextlib.contextmanager + def timer_context(): + self.start_outer(category) + try: + yield + finally: + self.stop_outer(category) + + return timer_context() + + def context(self, category: str) -> contextlib.AbstractContextManager: + """ + Context manager for timing a block of code. 
+ + Only one inner timer can be active at a time. + """ + + @contextlib.contextmanager + def timer_context(): + self.start(category) + try: + yield + finally: + self.stop() + + return timer_context() + + def start(self, category: str): + """ + Start an inner timer for the given category. + + Only one inner timer can be active at a time. + """ + if self._current_category is not None: + raise RuntimeError( + "GlobalTimer already has an active inner timer, " + f"{self._current_category}" + ) + self.start_outer(category) + self._current_category = category + + def start_outer(self, category: str): + """ + Start a timer for the given category. + + May be active at the same time as other timers. + """ + if self._active: + if category not in self._timers: + self._timers[category] = CumulativeTimer(category) + self._timers[category].start() + + def stop(self): + """ + Stop the currently active inner timer. + """ + if self._current_category is None: + raise RuntimeError("GlobalTimer does not have a running timer") + self.stop_outer(self._current_category) + self._current_category = None + + def stop_outer(self, category: str): + """ + Stop the timer for the given category. + + Does not change the currently active inner timer. + """ + if self._active: + self._timers[category].stop() + + def get_duration(self, category: str) -> float: + if self._active: + return self._timers[category].duration + else: + return np.nan + + def get_durations(self) -> Dict[str, float]: + return {category: timer.duration for category, timer in self._timers.items()} + + def log_durations(self): + for name, duration in self.get_durations().items(): + logging.info(f"{name} duration: {duration:.2f}s") diff --git a/fme/fme/core/typing_.py b/fme/fme/core/typing_.py index dcfac10..140691c 100644 --- a/fme/fme/core/typing_.py +++ b/fme/fme/core/typing_.py @@ -1,6 +1,29 @@ -from typing import Dict, Mapping +import dataclasses +from typing import Dict, Mapping, Optional import torch TensorMapping = Mapping[str, torch.Tensor] TensorDict = Dict[str, torch.Tensor] + + +@dataclasses.dataclass +class Slice: + """ + Configuration of a python `slice` built-in. + + Required because `slice` cannot be initialized directly by dacite. + + Parameters: + start: Start index of the slice. + stop: Stop index of the slice. + step: Step of the slice. + """ + + start: Optional[int] = None + stop: Optional[int] = None + step: Optional[int] = None + + @property + def slice(self) -> slice: + return slice(self.start, self.stop, self.step) diff --git a/fme/fme/core/wandb.py b/fme/fme/core/wandb.py index d7d0d6f..703b8ae 100644 --- a/fme/fme/core/wandb.py +++ b/fme/fme/core/wandb.py @@ -117,12 +117,13 @@ def configure(self, log_to_wandb: bool): self._configured = True def init(self, **kwargs): - """kwargs are passed to wandb.init""" + """Kwargs are passed to wandb.init.""" if not self._configured: raise RuntimeError( "must call WandB.configure before WandB init can be called" ) if self._enabled: + wandb.require("core") wandb.init(**kwargs) def watch(self, modules): diff --git a/fme/fme/core/weight_ops.py b/fme/fme/core/weight_ops.py index d1d1174..8f3ce9f 100644 --- a/fme/fme/core/weight_ops.py +++ b/fme/fme/core/weight_ops.py @@ -25,7 +25,7 @@ class CopyWeightsConfig: All parameters must be covered by either the include or exclude list, but not both. 
- Attributes: + Parameters: include: list of wildcard patterns to overwrite exclude: list of wildcard patterns to exclude from overwriting """ diff --git a/fme/fme/core/winds.py b/fme/fme/core/winds.py index 78a55d4..1513bcc 100644 --- a/fme/fme/core/winds.py +++ b/fme/fme/core/winds.py @@ -26,7 +26,6 @@ def u_v_to_x_y_z_wind( wy: y wind component wz: z wind component """ - # for a graphical proof of the equations used here, see # https://github.com/ai2cm/full-model/pull/355#issuecomment-1729773301 diff --git a/fme/fme/require_gpu.py b/fme/fme/require_gpu.py new file mode 100644 index 0000000..bac46cb --- /dev/null +++ b/fme/fme/require_gpu.py @@ -0,0 +1,9 @@ +import fme + +""" +Manually triggered for CI tests on GPU so that tests do not +default to CPU if driver issues prevent use of CUDA. +""" +device = str(fme.get_device()) +print(f"Device: {device}") +assert device.startswith("cuda") diff --git a/fme/fme/sht_fix.py b/fme/fme/sht_fix.py index cd5fa6b..2e8da49 100644 --- a/fme/fme/sht_fix.py +++ b/fme/fme/sht_fix.py @@ -10,8 +10,6 @@ [*] https://github.com/NVIDIA/torch-harmonics/blob/17eefa53468d1a885d72087918eba905fa53e10a/torch_harmonics/sht.py """ -USE_FIX = True - # coding=utf-8 @@ -49,12 +47,8 @@ import torch.nn as nn import torch.fft -if USE_FIX: - from torch_harmonics.quadrature import * - from torch_harmonics.legendre import * -else: - from .quadrature import * - from .legendre import * +from torch_harmonics.quadrature import legendre_gauss_weights, lobatto_weights, clenshaw_curtiss_weights +from torch_harmonics.legendre import precompute_legpoly, precompute_dlegpoly class RealSHT(nn.Module): @@ -115,10 +109,7 @@ def __init__(self, nlat, nlon, lmax=None, mmax=None, grid="lobatto", norm="ortho weights = torch.einsum('mlk,k->mlk', pct, weights) # remember quadrature weights - if USE_FIX: - self.weights = weights.float() - else: - self.register_buffer('weights', weights, persistent=False) + self.weights = weights.float() def extra_repr(self): """ @@ -144,8 +135,7 @@ def forward(self, x: torch.Tensor): xout = torch.zeros(out_shape, dtype=x.dtype, device=x.device) # contraction - if USE_FIX: - self.weights = self.weights.to(x.device) + self.weights = self.weights.to(x.device) xout[..., 0] = torch.einsum('...km,mlk->...lm', x[..., :self.mmax, 0], self.weights) xout[..., 1] = torch.einsum('...km,mlk->...lm', x[..., :self.mmax, 1], self.weights) x = torch.view_as_complex(xout) @@ -197,10 +187,7 @@ def __init__(self, nlat, nlon, lmax=None, mmax=None, grid="lobatto", norm="ortho pct = precompute_legpoly(self.mmax, self.lmax, t, norm=self.norm, inverse=True, csphase=self.csphase) # register buffer - if USE_FIX: - self.pct = pct.float() - else: - self.register_buffer('pct', pct, persistent=False) + self.pct = pct.float() def extra_repr(self): """ @@ -216,8 +203,7 @@ def forward(self, x: torch.Tensor): # Evaluate associated Legendre functions on the output nodes x = torch.view_as_real(x) - if USE_FIX: - self.pct = self.pct.to(x.device) + self.pct = self.pct.to(x.device) rl = torch.einsum('...lm, mlk->...km', x[..., 0], self.pct ) im = torch.einsum('...lm, mlk->...km', x[..., 1], self.pct ) xs = torch.stack((rl, im), -1) @@ -291,10 +277,7 @@ def __init__(self, nlat, nlon, lmax=None, mmax=None, grid="lobatto", norm="ortho weights[1] = -1 * weights[1] # remember quadrature weights - if USE_FIX: - self.weights = weights.float() - else: - self.register_buffer('weights', weights, persistent=False) + self.weights = weights.float() def extra_repr(self): """ @@ -380,10 +363,7 @@ def __init__(self, 
nlat, nlon, lmax=None, mmax=None, grid="lobatto", norm="ortho dpct = precompute_dlegpoly(self.mmax, self.lmax, t, norm=self.norm, inverse=True, csphase=self.csphase) # register weights - if USE_FIX: - self.dpct = dpct.float() - else: - self.register_buffer('dpct', dpct, persistent=False) + self.dpct = dpct.float() def extra_repr(self): """ @@ -401,8 +381,7 @@ def forward(self, x: torch.Tensor): # contraction - spheroidal component # real component - if USE_FIX: - self.dpct = self.dpct.to(x.device) + self.dpct = self.dpct.to(x.device) srl = torch.einsum('...lm,mlk->...km', x[..., 0, :, :, 0], self.dpct[0]) \ - torch.einsum('...lm,mlk->...km', x[..., 1, :, :, 1], self.dpct[1]) # iamg component diff --git a/fme/pyproject.toml b/fme/pyproject.toml index 165f261..0ce45e9 100644 --- a/fme/pyproject.toml +++ b/fme/pyproject.toml @@ -33,3 +33,22 @@ optional-dependencies.deploy = { file = "deploy-requirements.txt" } [tool.setuptools.packages] find = {} + +[tool.uv] +cache-keys = [ + { file = "requirements.txt" }, + { file = "dev-requirements.txt" }, + { file = "docs/requirements.txt" }, +] + +[tool.ruff.lint] +select = ["D", "E", "F", "I", "W"] +ignore = ["D1", "D200", "D205", "D212", "E203", "W293", "F541", "E402"] + +[tool.ruff.lint.per-file-ignores] +"*/__init__.py" = ["F401"] +"scripts/*" = ["D"] +"test_*.py" = ["D"] + +[tool.ruff.lint.pydocstyle] +convention = "google" diff --git a/fme/pytest.ini b/fme/pytest.ini deleted file mode 100644 index 4c45e1d..0000000 --- a/fme/pytest.ini +++ /dev/null @@ -1,3 +0,0 @@ -[pytest] -markers = - requires_gpu: these tests require a GPU to run \ No newline at end of file diff --git a/fme/requirements.txt b/fme/requirements.txt index 587ce0b..e1ed2a7 100644 --- a/fme/requirements.txt +++ b/fme/requirements.txt @@ -1,9 +1,9 @@ h5py imageio<=2.27.0 -moviepy +moviepy<2.0.0 # should be able to relax this after wandb updates past 0.18.7 netcdf4 numpy<2 -wandb +wandb[media] tensorly tensorly-torch xarray @@ -17,4 +17,4 @@ plotly matplotlib dask astropy-healpix -pandas +pandas \ No newline at end of file diff --git a/scripts/data_process/Makefile b/scripts/data_process/Makefile new file mode 100644 index 0000000..069b50f --- /dev/null +++ b/scripts/data_process/Makefile @@ -0,0 +1,246 @@ +# The dependencies of the scripts below (where not containerized) are installed in the "fv3net" conda environment +# which can be installed using fv3net's Makefile. See +# https://github.com/ai2cm/fv3net/blob/8ed295cf0b8ca49e24ae5d6dd00f57e8b30169ac/Makefile#L310 + +# Some data has been put to coldline storage making it expensive +# to access. In order to run processing on that data, set this +# variable to true. 
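+# For example: `make fv3gfs_climSST_stats_beaker_dataset ENABLE_COLDLINE=true`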
+ENABLE_COLDLINE ?= false + +RESOLUTION ?= 1deg +LAYERS ?= 8layer + +ROOT_BASE_CLIMSST_1DEG = gs://vcm-ml-raw-flexible-retention/2023-07-08-C96-FME-reference-ensemble/regridded-zarrs/gaussian_grid_180_by_360 +ROOT_BASE_CLIMSST_4DEG = gs://vcm-ml-raw-flexible-retention/2023-07-08-C96-FME-reference-ensemble/regridded-zarrs/gaussian_grid_45_by_90 +ROOT_BASE_AMIP_1DEG = gs://vcm-ml-raw-flexible-retention/2023-10-20-C96-FME-AMIP-ensemble-reference/regridded-zarrs/gaussian_grid_180_by_360 +ROOT_BASE_AMIP_4DEG = gs://vcm-ml-raw-flexible-retention/2023-10-20-C96-FME-AMIP-ensemble-reference/regridded-zarrs/gaussian_grid_45_by_90 +ROOT_INTERMEDIATE_CLIMSST_1DEG = gs://vcm-ml-intermediate/2023-08-09-vertically-resolved-1deg-fme-ensemble-dataset +ROOT_INTERMEDIATE_CLIMSST_4DEG = gs://vcm-ml-intermediate/2023-08-09-vertically-resolved-4deg-fme-ensemble-dataset +ROOT_INTERMEDIATE_AMIP_1DEG = gs://vcm-ml-intermediate/2023-10-27-vertically-resolved-1deg-fme-amip-ensemble-dataset +ROOT_INTERMEDIATE_AMIP_4DEG = gs://vcm-ml-intermediate/2023-10-27-vertically-resolved-4deg-fme-amip-ensemble-dataset +OUTPUT_DIR_CLIMSST_1DEG = /net/nfs/climate/data/2023-08-11-vertically-resolved-1deg-fme-ensemble-dataset +OUTPUT_DIR_CLIMSST_4DEG = /net/nfs/climate/data/2023-08-11-vertically-resolved-4deg-fme-ensemble-dataset +OUTPUT_DIR_AMIP_1DEG = /pscratch/sd/b/bhenn/data/2023-11-01-vertically-resolved-1deg-fme-amip-ensemble-dataset +OUTPUT_DIR_AMIP_4DEG = /pscratch/sd/b/bhenn/data/2023-11-01-vertically-resolved-4deg-fme-amip-ensemble-dataset +NAME_CLIMSST_1DEG = fv3gfs-ensemble +NAME_CLIMSST_4DEG = fv3gfs-ensemble-4deg +NAME_AMIP_1DEG = fv3gfs-AMIP-ensemble +NAME_AMIP_4DEG = fv3gfs-AMIP-ensemble-4deg + +ifeq ($(RESOLUTION),1deg) + ROOT_INTERMEDIATE ?= $(ROOT_INTERMEDIATE_CLIMSST_1DEG) + ROOT_INTERMEDIATE_AMIP ?= $(ROOT_INTERMEDIATE_AMIP_1DEG) + ROOT_BASE ?= $(ROOT_BASE_CLIMSST_1DEG) + ROOT_AMIP ?= $(ROOT_BASE_AMIP_1DEG) + OUTPUT_DIR ?= $(OUTPUT_DIR_CLIMSST_1DEG) + OUTPUT_DIR_AMIP ?= $(OUTPUT_DIR_AMIP_1DEG) + NAME ?= $(NAME_CLIMSST_1DEG) + NAME_AMIP ?= $(NAME_AMIP_1DEG) +else ifeq ($(RESOLUTION),4deg) + ROOT_INTERMEDIATE ?= $(ROOT_INTERMEDIATE_CLIMSST_4DEG) + ROOT_INTERMEDIATE_AMIP ?= $(ROOT_INTERMEDIATE_AMIP_4DEG) + ROOT_BASE ?= $(ROOT_BASE_CLIMSST_4DEG) + ROOT_AMIP ?= $(ROOT_BASE_AMIP_4DEG) + OUTPUT_DIR ?= $(OUTPUT_DIR_CLIMSST_4DEG) + OUTPUT_DIR_AMIP ?= $(OUTPUT_DIR_AMIP_4DEG) + NAME ?= $(NAME_CLIMSST_4DEG) + NAME_AMIP ?= $(NAME_AMIP_4DEG) +endif + +# the netCDF generation step is done locally usually, so user sets input and output directories +NC_INPUT ?= +NC_OUTPUT ?= + +SHIELD_RES ?= c96 + +.PHONY: shield_AMIP_dataset +shield_AMIP_dataset: + ./compute_dataset.sh \ + --config configs/shield-amip-ensemble-$(SHIELD_RES)-$(RESOLUTION)-8layer.yaml + +.PHONY: shield_AMIP_monthly_netcdfs +shield_AMIP_monthly_netcdfs: + ./convert_to_monthly_netcdf_fv3gfs.sh \ + --input-url $(NC_INPUT) \ + --n-ic 2 \ + --output-dir $(NC_OUTPUT) \ + --start-date 1940-01-01 \ + --end-date 2021-12-31 + +.PHONY: shield_c24_4deg_climSST_dataset +shield_c24_4deg_climSST_dataset: + ./compute_dataset.sh --config configs/shield-c24-ensemble-4deg-8layer.yaml + +.PHONY: shield_c24_4deg_climSST_monthly_netcdfs +shield_c24_4deg_climSST_monthly_netcdfs: + ./convert_to_monthly_netcdf_fv3gfs.sh \ + --input-url gs://vcm-ml-intermediate/2024-04-05-vertically-resolved-4deg-c24-shield-fme-ensemble-dataset \ + --n-ic 21 \ + --output-dir $(NC_OUTPUT) \ + --start-date 2021-01-01 \ + --end-date 2030-12-31 + +shield_c24_4deg_climSST_stats_beaker_dataset: + 
	./compute_stats.sh --config configs/shield-c24-ensemble-4deg-8layer.yaml
+
+.PHONY: shield_c96_4deg_climSST_dataset
+shield_c96_4deg_climSST_dataset:
+	./compute_dataset.sh --config configs/shield-c96-4deg-8layer.yaml
+
+.PHONY: shield_c96_4deg_climSST_monthly_netcdfs
+shield_c96_4deg_climSST_monthly_netcdfs:
+	./convert_to_monthly_netcdf_fv3gfs.sh \
+		--input-url gs://vcm-ml-intermediate/2024-04-02-vertically-resolved-4deg-c96-shield-fme-dataset \
+		--n-ic 1 \
+		--output-dir $(NC_OUTPUT) \
+		--start-date 2035-01-01 \
+		--end-date 2060-12-31
+
+shield_c96_4deg_climSST_stats_beaker_dataset:
+	./compute_stats.sh --config configs/shield-c96-4deg-8layer.yaml
+
+.PHONY: fv3gfs_1deg_climSST_dataset
+fv3gfs_1deg_climSST_dataset:
+	$(MAKE) fv3gfs_climSST_dataset RESOLUTION=1deg LAYERS=8layer
+
+.PHONY: fv3gfs_climSST_dataset
+fv3gfs_climSST_dataset:
+	./compute_dataset.sh --config configs/fv3gfs-ensemble-$(RESOLUTION)-$(LAYERS).yaml
+
+.PHONY: fv3gfs_1deg_climSST_monthly_netcdfs
+fv3gfs_1deg_climSST_monthly_netcdfs:
+	$(MAKE) fv3gfs_climSST_monthly_netcdfs RESOLUTION=1deg
+
+.PHONY: enable_coldline_check
+enable_coldline_check:
+	@if [ "$(ENABLE_COLDLINE)" != "true" ]; then \
+		echo "Processing target is disabled because its input data is in coldline storage;"; \
+		echo "to run it, ENABLE_COLDLINE must be set to true. Exiting."; \
+		exit 1; \
+	fi
+
+.PHONY: fv3gfs_climSST_monthly_netcdfs
+fv3gfs_climSST_monthly_netcdfs:
+	./convert_to_monthly_netcdf_fv3gfs.sh \
+		--input-url $(ROOT_INTERMEDIATE) \
+		--n-ic 11 \
+		--output-dir $(OUTPUT_DIR) \
+		--start-date 2021-01-01 \
+		--end-date 2030-12-31
+
+.PHONY: fv3gfs_1deg_climSST_stats_beaker_dataset
+fv3gfs_1deg_climSST_stats_beaker_dataset:
+	$(MAKE) fv3gfs_climSST_stats_beaker_dataset RESOLUTION=1deg
+
+fv3gfs_climSST_stats_beaker_dataset: enable_coldline_check
+	./compute_stats.sh --config configs/fv3gfs-ensemble-$(RESOLUTION)-$(LAYERS).yaml
+
+# This took roughly 10 hours to complete. The roundtrip_filter adds significant computational time.
+# If we plan to do this regularly, parallelizing the script across time using xpartition to
+# launch jobs for different time chunks would probably be a good idea.
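+# (Not implemented here: one option would be to initialize the output zarr once, then
+# have each job compute and write a distinct slab of time indices, so no final
+# concatenation step is needed.)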
+fv3gfs_climSST_c48_baseline_dataset: enable_coldline_check + ./compute_dataset.sh --config configs/fv3gfs-c48-ensemble-1deg-8layer.yaml + +.PHONY: fv3gfs_climSST_c48_baseline_monthly_netcdfs +fv3gfs_climSST_c48_baseline_monthly_netcdfs: + python convert_to_monthly_netcdf.py \ + --prepend-nans \ + gs://vcm-ml-intermediate/2023-09-01-vertically-resolved-1deg-fme-c48-baseline-dataset/ic_0011.zarr \ + /net/nfs/climate/data/2023-09-12-vertically-resolved-1deg-fme-c48-baseline-dataset-truncated-065/ic_0011 \ + --start-date 2021-01-01 \ + --end-date 2030-12-31 + +.PHONY: fv3gfs_1deg_AMIP_dataset +fv3gfs_1deg_AMIP_dataset: + $(MAKE) fv3gfs_AMIP_dataset RESOLUTION=1deg LAYERS=8layer + +fv3gfs_AMIP_dataset: enable_coldline_check + ./compute_dataset.sh --config configs/fv3gfs-amip-ensemble-$(RESOLUTION)-$(LAYERS).yaml + +.PHONY: fv3gfs_1deg_AMIP_monthly_netcdfs +fv3gfs_1deg_AMIP_monthly_netcdfs: + $(MAKE) fv3gfs_AMIP_monthly_netcdfs RESOLUTION=1deg + +.PHONY: fv3gfs_AMIP_monthly_netcdfs +fv3gfs_AMIP_monthly_netcdfs: + ./convert_to_monthly_netcdf_fv3gfs.sh \ + --input-url $(ROOT_INTERMEDIATE_AMIP) \ + --n-ic 4 \ + --output-dir $(OUTPUT_DIR_AMIP) \ + --start-date 1990-01-01 \ + --end-date 2019-12-31 + +.PHONY: fv3gfs_1deg_AMIP_stats_beaker_dataset +fv3gfs_1deg_AMIP_stats_beaker_dataset: + $(MAKE) fv3gfs_AMIP_stats_beaker_dataset RESOLUTION=1deg + +fv3gfs_AMIP_stats_beaker_dataset: enable_coldline_check + ./compute_stats.sh --config configs/fv3gfs-amip-ensemble-$(RESOLUTION)-$(LAYERS).yaml + +# TODO: Add AMIP baseline C48 dataset processing when available + +.PHONY: shield_som_spin_up_c96_dataset +shield_som_c96_spin_up_dataset: + ./compute_dataset.sh --config configs/shield-som-spin-up-c96-1deg-$(LAYERS).yaml + +.PHONY: shield_som_ensemble_c96_dataset +shield_som_ensemble_c96_dataset: + ./compute_dataset.sh --config configs/shield-som-ensemble-c96-$(RESOLUTION)-$(LAYERS).yaml + +.PHONY: shield_som_abrupt_co2_increase_c96_dataset +shield_som_abrupt_co2_increase_c96_dataset: + ./compute_dataset.sh --config configs/shield-som-abrupt-co2-increase-c96-$(RESOLUTION)-$(LAYERS).yaml + +.PHONY: shield_som_increasing_co2_c96_dataset +shield_som_c96_increasing_co2_dataset: + ./compute_dataset.sh --config configs/shield-som-increasing-co2-c96-$(RESOLUTION)-$(LAYERS).yaml + +.PHONY: shield_som_c96_radiation_multi_call_dataset +shield_som_c96_radiation_multi_call_dataset: + ./compute_dataset.sh --config configs/shield-som-radiation-multi-call-c96-1deg-$(LAYERS).yaml + +.PHONY: shield_som_c24_dataset +shield_som_c24_dataset: + ./compute_dataset.sh --config configs/shield-som-c24-4deg-$(LAYERS).yaml + +.PHONY: shield_som_c24_tuned_cdmbgwd_dataset +shield_som_c24_tuned_cdmbgwd_dataset: + ./compute_dataset.sh --config configs/shield-som-c24-tuned-cdmbgwd-4deg-$(LAYERS).yaml + +# In total (not including full_field stats), this took 7 hours 15 minutes to +# complete using a single Perlmutter CPU node. 
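+# The e3smv2 target below submits generate_datasets_e3smv2.sh as a Slurm batch job via sbatch.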
+ +.PHONY: e3smv2_1deg_climSST_dataset +e3smv2_1deg_climSST_dataset: + sbatch -J 2024-07-10-e3smv2-1deg-testing generate_datasets_e3smv2.sh \ + --input-dir /global/cfs/cdirs/m4492/rebassoo/e3sm_test/post/atm/180x360_gaussian/ts \ + --config configs/e3sm-1deg-8layer.yaml \ + --zarr /global/cfs/cdirs/m4492/fme-preprocess/zarr/2024-07-10-e3smv2-1deg-testing.zarr \ + --output-dir /global/cfs/cdirs/m4492/fme-preprocess/2024-07-10-e3smv2-1deg-testing + +.PHONY: era5_1deg_stats_beaker_dataset +era5_1deg_stats_beaker_dataset: + ./compute_stats.sh --config configs/era5-1deg-8layer-1940-2022.yaml + +.PHONY: era5_1deg_16layer_stats_beaker_dataset +era5_1deg_16layer_stats_beaker_dataset: + ./compute_stats.sh --config configs/era5-1deg-16layer-1940-2022.yaml + +.PHONY: compute_cm4_trial_run_atmosphere_dataset +compute_cm4_trial_run_atmosphere_dataset: + python -u compute_dataset.py \ + --config configs/pre-industrial-CM4-1deg-8layer-trial-run.yaml \ + --run-directory gs://vcm-ml-raw-flexible-retention/2024-08-10-pre-industrial-CM4-simulation/regridded-zarrs/gaussian_grid_180_by_360/trial-run \ + --output-store gs://vcm-ml-intermediate/2024-09-20-cm4-1deg-8layer-trial-run.zarr + +.PHONY: compute_cm4_trial_run_atmosphere_stats +compute_cm4_trial_run_atmosphere_stats: + python -u get_stats.py configs/pre-industrial-CM4-1deg-8layer-trial-run.yaml 0 + +.PHONY: build_docker_image +build_docker_image: + docker build -f beakerpy.Dockerfile -t us.gcr.io/vcm-ml/beaker-py . + +.PHONY: push_docker_image +push_docker_image: build_docker_image + docker push us.gcr.io/vcm-ml/beaker-py diff --git a/scripts/data_process/README.md b/scripts/data_process/README.md index 40dc8ff..4c1706f 100644 --- a/scripts/data_process/README.md +++ b/scripts/data_process/README.md @@ -1,3 +1,28 @@ # Data processing for full model emulation training -This directory contains scripts for generating various datasets needed for FME training \ No newline at end of file +This directory contains scripts for generating various datasets needed for FME training, including the FV3GFS primary, baseline, and stats datasets. + +It also contains scripts for generating E3SM training data. + +The first step in the process to create intermediate datasets (e.g. `make fv3gfs_AMIP_dataset`) uses argo, and can be run on your Google VM. +See the vcm-workflow-control repo for instructions on how to install and run argo. + +The second step, which produces monthly netCDF files locally (e.g. `make fv3gfs_AMIP_monthly_netcdfs`), can be run on cirrascale in an interactive session. +To create an interactive session, run the following command from the `scripts/data_process` directory: + +``` +beaker session create --budget ai2/climate --image beaker://jeremym/fme-2bc0033e --gpus 0 --mount hostPath:///net/nfs/climate=/net/nfs/climate --mount hostpath://$(pwd)=/full-model --workdir /full-model/scripts/data_process +``` + +Doing so will require that your current working directory is a mountable path (e.g. something in /data). +If you'd like to write to a different directory than /net/nfs/climate, you can mount that path instead. + +Once inside the image, you will need to authorize access to GCS by running `gcloud auth application-default login` and following the instructions, including to run `gcloud config set project vcm-ml` afterwards. + +You can then produce the monthly netCDFs in a target directory by modifying the `OUTPUT_DIR` or `OUTPUT_DIR_AMIP` variable in the make command below. 
+ +``` +make fv3gfs_AMIP_monthly_netcdfs RESOLUTION=4deg OUTPUT_DIR_AMIP=/data/shared/2023-12-20-vertically-resolved-4deg-fme-amip-ensemble-dataset +``` + +The stats dataset creation step (e.g. `make fv3gfs_AMIP_stats_beaker_dataset`) must be run in the fme conda environment (created by `make create_environment` at the top level of this repo), and additionally requires the beaker client is installed ([install instructions](https://beaker-docs.apps.allenai.org/start/install.html)). diff --git a/scripts/data_process/beakerpy.Dockerfile b/scripts/data_process/beakerpy.Dockerfile new file mode 100644 index 0000000..92e7aa7 --- /dev/null +++ b/scripts/data_process/beakerpy.Dockerfile @@ -0,0 +1,3 @@ +from python:3.10-slim + +RUN pip install beaker-py==1.30.0 click dacite fsspec==2024.6.1 gcsfs==2024.6.1 diff --git a/scripts/data_process/combine_stats.py b/scripts/data_process/combine_stats.py index c398cb7..d86d565 100644 --- a/scripts/data_process/combine_stats.py +++ b/scripts/data_process/combine_stats.py @@ -36,7 +36,7 @@ class Config: def add_history_attrs(ds, config_filename: str, stats_output_dir: str): ds.attrs["history"] = ( - "Created by ace/fv3gfs_data_process/combine_stats.py from " + "Created by full-model/fv3gfs_data_process/combine_stats.py from " f"configuration file {config_filename} using inputs at {stats_output_dir}." ) diff --git a/scripts/data_process/compute_dataset.py b/scripts/data_process/compute_dataset.py index 4a948a4..7773eb3 100755 --- a/scripts/data_process/compute_dataset.py +++ b/scripts/data_process/compute_dataset.py @@ -10,20 +10,25 @@ # for a workflow which parallelizes this script across the 11-member ensemble and runs # it on our GKE cluster. +import abc import dataclasses import os -from typing import List, Mapping, MutableMapping, Optional, Sequence, Tuple +import sys +from typing import Dict, List, Mapping, MutableMapping, Optional, Sequence, Tuple, Union import click import dacite import fsspec import numpy as np import xarray as xr -import xpartition # noqa: 401 +import xpartition # noqa: F401 import xtorch_harmonics import yaml from dask.diagnostics import ProgressBar +sys.path.append(os.path.dirname(os.path.abspath(__file__))) +from get_stats import StatsConfig + # constants are defined as in FV3GFS model # https://github.com/ai2cm/fv3gfs-fortran/blob/master/FMS/constants/constants.F90 LATENT_HEAT_OF_VAPORIZATION = 2.5e6 # J/kg @@ -41,6 +46,7 @@ class StandardNameMapping: longitude_dim: str = "grid_xt" latitude_dim: str = "grid_yt" vertical_dim: str = "pfull" + vertical_interface_dim: str = "phalf" time_dim: str = "time" surface_pressure: str = "PRESsfc" latent_heat_flux: str = "LHTFLsfc" @@ -56,6 +62,11 @@ class StandardNameMapping: snow_mixing_ratio: str = "snow_mixing_ratio" northward_wind: str = "northward_wind" eastward_wind: str = "eastward_wind" + surface_evaporation_rate: str = "surface_evaporation_rate" + land_fraction: str = "land_fraction" + ocean_fraction: str = "ocean_fraction" + sea_ice_fraction: str = "sea_ice_fraction" + hybrid_level_coeffs: List[str] = dataclasses.field(default_factory=list) def __post_init__(self): self.horizontal_dims: List[str] = [self.longitude_dim, self.latitude_dim] @@ -65,15 +76,6 @@ def __post_init__(self): self.pwat_tendency = f"tendency_of_{self.total_water_path}" self.time_derivative_names = [self.total_water_path] - self.water_species: List[str] = [ - self.specific_humidity, - self.cloud_water_mixing_ratio, - self.cloud_ice_mixing_ratio, - self.graupel_mixing_ratio, - self.rain_mixing_ratio, - 
self.snow_mixing_ratio, - ] - self.vertically_resolved: List[str] = [ self.specific_total_water, self.air_temperature, @@ -85,17 +87,60 @@ def __post_init__(self): self.dropped_variables: List[str] = ( self.water_species + self.vertically_resolved - + [self.pressure_thickness, self.vertical_dim, self.precipitable_water_path] + + [self.pressure_thickness, self.vertical_dim] ) + if self.precipitable_water_path.lower() != "none": + self.dropped_variables.append(self.precipitable_water_path) + + @property + def water_species(self) -> List[str]: + return [ + item + for item in [ + self.specific_humidity, + self.cloud_water_mixing_ratio, + self.cloud_ice_mixing_ratio, + self.graupel_mixing_ratio, + self.rain_mixing_ratio, + self.snow_mixing_ratio, + ] + if item.lower() != "none" + ] + + +@dataclasses.dataclass +class DLWPNameMapping(StandardNameMapping): + longitude_dim: str = "longitude" + latitude_dim: str = "latitude" + face_dim: str = "face" + width_dim: str = "width" + height_dim: str = "height" + + def __post_init__(self): + super().__post_init__() + + self.horizontal_dims: List[str] = [ + self.face_dim, + self.width_dim, + self.height_dim, + ] + self.lat_lon_dims: List[str] = [self.latitude_dim, self.longitude_dim] @dataclasses.dataclass -class ChunkingConfig: +class _ChunkingConfig(abc.ABC): time_dim: int = 160 + + @abc.abstractmethod + def get_chunks(self, standard_names: StandardNameMapping) -> Dict[str, int]: ... + + +@dataclasses.dataclass +class ChunkingConfig(_ChunkingConfig): latitude_dim: int = -1 longitude_dim: int = -1 - def get_chunks(self, standard_names: StandardNameMapping): + def get_chunks(self, standard_names: StandardNameMapping) -> Dict[str, int]: return { standard_names.time_dim: self.time_dim, standard_names.longitude_dim: self.longitude_dim, @@ -103,11 +148,32 @@ def get_chunks(self, standard_names: StandardNameMapping): } +@dataclasses.dataclass +class DLWPChunkingConfig(_ChunkingConfig): + face_dim: int = -1 + width_dim: int = -1 + height_dim: int = -1 + + def get_chunks(self, standard_names: StandardNameMapping) -> Dict[str, int]: + dlwp_names = standard_names + if not isinstance(dlwp_names, DLWPNameMapping): + raise TypeError( + "Expected DLWPChunkingConfig to be passed type of DLWPNameMapping." + ) + chunks = { + dlwp_names.time_dim: self.time_dim, + dlwp_names.face_dim: self.face_dim, + dlwp_names.width_dim: self.width_dim, + dlwp_names.height_dim: self.height_dim, + } + return chunks + + @dataclasses.dataclass class DatasetComputationConfig: """Configuration of computation details for an FME reference dataset. - Attributes: + Parameters: reference_vertical_coordinate_file: path to netCDF file containing vertical coordinate definition for the reference simulation. vertical_coarsening_indices: list of tuples defining the ranges of @@ -123,6 +189,8 @@ class DatasetComputationConfig: names of variables in the dataset. chunking: (optional) mapping of standard dimension names to desired output chunk sizes + time_invariant_dir: (optional) path to directory containing time-invariant data + This option is used for E3SMv2 dataset. 
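+
+    Example (illustrative values only; see the YAML files under configs/ for real ones)::
+
+        reference_vertical_coordinate_file: gs://some-bucket/fv_core.res.nc
+        vertical_coarsening_indices: [[0, 19], [19, 30], [30, 38]]
+        n_split: 65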
""" reference_vertical_coordinate_file: str @@ -131,24 +199,31 @@ class DatasetComputationConfig: n_split: int = 65 renaming: Mapping[str, str] = dataclasses.field(default_factory=dict) roundtrip_fraction_kept: Optional[float] = None - standard_names: StandardNameMapping = StandardNameMapping() - chunking: ChunkingConfig = ChunkingConfig() + standard_names: Union[StandardNameMapping, DLWPNameMapping] = dataclasses.field( + default_factory=StandardNameMapping + ) + chunking: Union[ChunkingConfig, DLWPChunkingConfig] = dataclasses.field( + default_factory=ChunkingConfig + ) + time_invariant_dir: Optional[str] = None @dataclasses.dataclass class DatasetConfig: """Dataset provenance for a set of reference simulations. - Attributes: + Parameters: runs: mapping of short names to full paths of reference datasets. output_directory: path to place output of computation script. dataset_computation: configuration details for dataset computation. + stats_config: configuration to retrieve statistics dataset """ runs: Mapping[str, str] data_output_directory: str dataset_computation: DatasetComputationConfig + stats: StatsConfig @classmethod def from_file(cls, path: str) -> "DatasetConfig": @@ -156,7 +231,7 @@ def from_file(cls, path: str) -> "DatasetConfig": data = yaml.safe_load(file) return dacite.from_dict( - data_class=cls, data=data, config=dacite.Config(cast=[tuple]) + data_class=cls, data=data, config=dacite.Config(cast=[tuple], strict=True) ) @@ -222,16 +297,90 @@ def get_coarse_ak_bk( return xr.Dataset(data) +def compute_ocean_fraction( + ds: xr.Dataset, + output_name: str, + land_fraction_name: str, + sea_ice_fraction_name: str, +) -> xr.Dataset: + """Compute latent heat flux, if needed.""" + if output_name in ds.data_vars: + # if ocean_fraction is already computed, assume that NaNs have been handled + return ds + ds[sea_ice_fraction_name] = ds[sea_ice_fraction_name].fillna(0.0) + ocean_fraction = 1 - ds[sea_ice_fraction_name] - ds[land_fraction_name] + negative_ocean = xr.where(ocean_fraction < 0, ocean_fraction, 0) + ocean_fraction -= negative_ocean + ds["sea_ice_fraction"] += negative_ocean + ocean_fraction.attrs["units"] = "unitless" + ocean_fraction.attrs["long_name"] = "fraction of grid cell area occupied by ocean" + return ds.assign({output_name: ocean_fraction}) + + +def compute_latent_heat_flux( + ds: xr.Dataset, + output_name: str, + evaporation_name: Optional[str] = None, +) -> xr.Dataset: + """Compute latent heat flux, if needed.""" + if output_name in ds.data_vars: + return ds + assert ( + evaporation_name is not None + ), f"{output_name} not found in ds, evaporation_name must be provided." 
+ latent_heat_flux = ds[evaporation_name] * LATENT_HEAT_OF_VAPORIZATION + latent_heat_flux.attrs["units"] = "W/m^2" + latent_heat_flux.attrs["long_name"] = "Latent heat flux" + return ds.assign({output_name: latent_heat_flux}).drop(evaporation_name) + + def compute_specific_total_water( ds: xr.Dataset, water_condensate_names: Sequence[str], output_name: str ) -> xr.Dataset: """Compute specific total water from individual water species.""" - specific_total_water = sum([ds[name] for name in water_condensate_names]) + specific_total_water: xr.DataArray = sum( + [ds[name] for name in water_condensate_names] + ) specific_total_water.attrs["units"] = "kg/kg" specific_total_water.attrs["long_name"] = output_name.replace("_", " ") return ds.assign({output_name: specific_total_water}) +def compute_pressure_thickness( + ds: xr.Dataset, + vertical_coordinate_file: str, + vertical_dim_name: str, + surface_pressure_name: str, + output_name: str, + z_dim: str = "xaxis_1", +): + if output_name in ds.data_vars: + return ds + + with fsspec.open(vertical_coordinate_file) as f: + vertical_coordinate = xr.open_dataset(f).load() + # squeeze out the singleton time dimension + vertical_coord = vertical_coordinate.squeeze(drop=True) + + sfc_pressure = ds[surface_pressure_name].expand_dims( + {z_dim: vertical_coord[z_dim]}, axis=3 + ) + phalf = sfc_pressure * vertical_coord["bk"] + vertical_coord["ak"] + + thickness = ( + phalf.diff(dim=z_dim) + .rename({z_dim: vertical_dim_name}) + .rename(output_name) + .assign_coords( + {vertical_dim_name: (vertical_dim_name, ds[vertical_dim_name].values)} + ) + ) + + thickness.attrs["units"] = "Pa" + thickness.attrs["long_name"] = output_name.replace("_", " ") + return ds.assign({output_name: thickness}) + + def compute_vertical_coarsening( ds: xr.Dataset, vertically_resolved_names: Sequence[str], @@ -325,8 +474,12 @@ def assert_global_dry_air_mass_conservation( column_dry_air_mass = ( ds[surface_pressure_name] - ds[total_water_path_name] * GRAVITY ) - weights = np.cos(np.deg2rad(ds[latitude_dim])) - global_dry_air_mass = column_dry_air_mass.weighted(weights).mean(dim=dims) + if latitude_dim in dims: + weights = np.cos(np.deg2rad(ds[latitude_dim])) + global_dry_air_mass = column_dry_air_mass.weighted(weights).mean(dim=dims) + else: + global_dry_air_mass = column_dry_air_mass.mean(dim=dims) + global_dry_air_mass_tendency = global_dry_air_mass.diff(time_dim) print("Mean absolute global dry air pressure tendency [Pa]:") print(np.abs(global_dry_air_mass_tendency).mean().values) @@ -385,11 +538,29 @@ def construct_lazy_dataset( lon_dim=standard_names.longitude_dim, fraction_modes_kept=config.roundtrip_fraction_kept, ) + ds = compute_ocean_fraction( + ds, + output_name=standard_names.ocean_fraction, + land_fraction_name=standard_names.land_fraction, + sea_ice_fraction_name=standard_names.sea_ice_fraction, + ) + ds = compute_latent_heat_flux( + ds, + output_name=standard_names.latent_heat_flux, + evaporation_name=standard_names.surface_evaporation_rate, + ) ds = compute_specific_total_water( ds, water_condensate_names=standard_names.water_species, output_name=standard_names.specific_total_water, ) + ds = compute_pressure_thickness( + ds, + vertical_coordinate_file=config.reference_vertical_coordinate_file, + vertical_dim_name=standard_names.vertical_dim, + surface_pressure_name=standard_names.surface_pressure, + output_name=standard_names.pressure_thickness, + ) ds = compute_vertical_coarsening( ds, vertically_resolved_names=standard_names.vertically_resolved, @@ -424,7 +595,7 @@ def 
construct_lazy_dataset( chunks = config.chunking.get_chunks(standard_names) ds = ds.chunk(chunks) ds.attrs["history"] = ( - "Dataset computed by ace/scripts/data_process" + "Dataset computed by full-model/scripts/data_process" "/compute_dataset_fv3gfs.py" f" script, using following input zarrs: {urls.values()}." ) @@ -473,6 +644,7 @@ def main( surface_pressure_name=standard_names.surface_pressure, total_water_path_name=standard_names.total_water_path, latitude_dim=standard_names.latitude_dim, + time_dim=standard_names.time_dim, ) assert_global_moisture_conservation( ds, @@ -482,6 +654,7 @@ def main( latent_heat_flux_name=standard_names.latent_heat_flux, latent_heat_of_vaporization=LATENT_HEAT_OF_VAPORIZATION, precip_rate_name=standard_names.precip_rate, + time_dim=standard_names.time_dim, ) ds = ds.drop(standard_names.dropped_variables) print(f"Output dataset size is {ds.nbytes / 1e9} GB") diff --git a/scripts/data_process/compute_dataset.sh b/scripts/data_process/compute_dataset.sh new file mode 100755 index 0000000..619c629 --- /dev/null +++ b/scripts/data_process/compute_dataset.sh @@ -0,0 +1,47 @@ +#!/bin/bash + +set -e + +COMPUTE_DATASET=true + +while [[ "$#" -gt 0 ]] +do case $1 in + --config) CONFIG="$2" + shift;; + --stats-only) COMPUTE_DATASET=false;; + *) echo "Unknown parameter passed: $1" + exit 1;; +esac +shift +done + +if [[ -z "${CONFIG}" ]] +then + echo "Option --config missing" + exit 1; +fi + +names=($(yq -r '.runs | to_entries[].key' ${CONFIG})) +run_directories=($(yq -r '.runs | to_entries[].value' ${CONFIG})) +output_directory=$(yq -r '.data_output_directory' ${CONFIG}) +runs_count=$(yq -r '.runs | length' ${CONFIG}) +runs_count_minus_one=$(($runs_count - 1)) + +# Capture the output of the argo submit command +output=$(argo submit compute_dataset_argo_workflow.yaml \ + -p compute_dataset=${COMPUTE_DATASET} \ + -p python_script="$(< compute_dataset.py)" \ + -p get_stats_script="$(< get_stats.py)" \ + -p combine_stats_script="$(< combine_stats.py)" \ + -p upload_stats_script="$(< upload_stats.py)" \ + -p config="$(< ${CONFIG})" \ + -p names="${names[*]}" \ + -p run_directories="${run_directories[*]}" \ + -p output_directory="${output_directory}" \ + -p runs_count_minus_one=${runs_count_minus_one}) + +# Extract the job name from the output +job_name=$(echo "$output" | grep 'Name:' | awk '{print $2}') + +# Print the job name +echo "Argo job submitted: $job_name" diff --git a/scripts/data_process/compute_dataset_argo_workflow.yaml b/scripts/data_process/compute_dataset_argo_workflow.yaml new file mode 100644 index 0000000..9c05fb0 --- /dev/null +++ b/scripts/data_process/compute_dataset_argo_workflow.yaml @@ -0,0 +1,282 @@ +apiVersion: argoproj.io/v1alpha1 +kind: Workflow +metadata: + generateName: compute-fme-dataset-ensemble- +spec: + entrypoint: compute-fme-dataset-ensemble + volumes: + - name: gcp-key-secret + secret: + defaultMode: 420 + secretName: gcp-key + arguments: + parameters: + - name: python_script + - name: get_stats_script + - name: combine_stats_script + - name: upload_stats_script + - name: config + - name: names + - name: run_directories + - name: output_directory + - name: runs_count_minus_one + - name: compute_dataset + value: "true" # default value, can be overridden + templates: + - name: compute-fme-dataset-ensemble + steps: + - - name: compute-fme-dataset-individual + when: "{{workflow.parameters.compute_dataset}} == true" + template: compute-fme-dataset-individual + arguments: + parameters: + - name: python_script + value: 
"{{workflow.parameters.python_script}}" + - name: stats_script + value: "{{workflow.parameters.get_stats_script}}" + - name: config + value: "{{workflow.parameters.config}}" + - name: names + value: "{{workflow.parameters.names}}" + - name: run_directories + value: "{{workflow.parameters.run_directories}}" + - name: output_directory + value: "{{workflow.parameters.output_directory}}" + - name: run + value: "{{item}}" + withSequence: + start: "0" + end: "{{workflow.parameters.runs_count_minus_one}}" + - - name: get-stats + template: get-stats + arguments: + parameters: + - name: python_script + value: "{{workflow.parameters.get_stats_script}}" + - name: config + value: "{{workflow.parameters.config}}" + - name: run + value: "{{item}}" + withSequence: + start: "0" + end: "{{workflow.parameters.runs_count_minus_one}}" + - - name: combine-stats + template: combine-stats + arguments: + parameters: + - name: python_script + value: "{{workflow.parameters.combine_stats_script}}" + - name: config + value: "{{workflow.parameters.config}}" + - - name: upload-beaker-stats + template: upload-beaker-stats + arguments: + parameters: + - name: python_script + value: "{{workflow.parameters.upload_stats_script}}" + - name: config + value: "{{workflow.parameters.config}}" + - name: compute-fme-dataset-individual + tolerations: + - effect: NoSchedule + key: dedicated + value: med-sim-pool + inputs: + parameters: + - name: python_script + - name: stats_script + - name: config + - name: names + - name: run_directories + - name: output_directory + - name: run + container: + image: us.gcr.io/vcm-ml/fv3net:3d1589321e40cddc06bb88c22b44f597646473b2 + resources: + limits: + cpu: "8000m" + memory: "48Gi" + requests: + cpu: "7500m" + memory: "48Gi" + command: ["bash", "-c", "-e"] + args: + - | + cat << EOF > script.py + {{inputs.parameters.python_script}} + EOF + + cat << EOF > get_stats.py + {{inputs.parameters.stats_script}} + EOF + + cat << EOF > config.yaml + {{inputs.parameters.config}} + EOF + + run={{inputs.parameters.run}} + names=({{inputs.parameters.names}}) + run_directories=({{inputs.parameters.run_directories}}) + output_directory={{inputs.parameters.output_directory}} + + name="${names[${run}]}" + run_directory="${run_directories[${run}]}" + + output_store=${output_directory}/${name}.zarr + + python script.py --config config.yaml \ + --run-directory ${run_directory} \ + --output-store ${output_store} + env: + - name: GOOGLE_APPLICATION_CREDENTIALS + value: /secret/gcp-credentials/key.json + - name: CLOUDSDK_AUTH_CREDENTIAL_FILE_OVERRIDE + value: /secret/gcp-credentials/key.json + volumeMounts: + - mountPath: /secret/gcp-credentials + name: gcp-key-secret + - name: compute-fme-dataset-ensemble-stats + steps: + - - name: get-stats + template: get-stats + arguments: + parameters: + - name: python_script + value: "{{workflow.parameters.get_stats_script}}" + - name: config + value: "{{workflow.parameters.config}}" + - name: run + value: "{{item}}" + withSequence: + start: "0" + end: "{{workflow.parameters.runs_count_minus_one}}" + - - name: combine-stats + template: combine-stats + arguments: + parameters: + - name: python_script + value: "{{workflow.parameters.combine_stats_script}}" + - name: config + value: "{{workflow.parameters.config}}" + - name: get-stats + tolerations: + - effect: NoSchedule + key: dedicated + value: med-sim-pool + inputs: + parameters: + - name: python_script + - name: config + - name: run + container: + image: us.gcr.io/vcm-ml/fv3net:3d1589321e40cddc06bb88c22b44f597646473b2 + 
resources: + limits: + cpu: "8000m" + memory: "27Gi" + requests: + cpu: "7500m" + memory: "27Gi" + command: ["bash", "-c", "-e"] + args: + - | + cat << EOF > script.py + {{inputs.parameters.python_script}} + EOF + + cat << EOF > config.yaml + {{inputs.parameters.config}} + EOF + + run={{inputs.parameters.run}} + + python script.py config.yaml ${run} + env: + - name: GOOGLE_APPLICATION_CREDENTIALS + value: /secret/gcp-credentials/key.json + - name: CLOUDSDK_AUTH_CREDENTIAL_FILE_OVERRIDE + value: /secret/gcp-credentials/key.json + volumeMounts: + - mountPath: /secret/gcp-credentials + name: gcp-key-secret + - name: combine-stats + tolerations: + - effect: NoSchedule + key: dedicated + value: med-sim-pool + inputs: + parameters: + - name: python_script + - name: config + container: + image: us.gcr.io/vcm-ml/fv3net:3d1589321e40cddc06bb88c22b44f597646473b2 + resources: + limits: + cpu: "8000m" + memory: "27Gi" + requests: + cpu: "7500m" + memory: "27Gi" + command: ["bash", "-c", "-e"] + args: + - | + cat << EOF > script.py + {{inputs.parameters.python_script}} + EOF + + cat << EOF > config.yaml + {{inputs.parameters.config}} + EOF + + python script.py config.yaml + env: + - name: GOOGLE_APPLICATION_CREDENTIALS + value: /secret/gcp-credentials/key.json + - name: CLOUDSDK_AUTH_CREDENTIAL_FILE_OVERRIDE + value: /secret/gcp-credentials/key.json + volumeMounts: + - mountPath: /secret/gcp-credentials + name: gcp-key-secret + - name: upload-beaker-stats + tolerations: + - effect: NoSchedule + key: dedicated + value: med-sim-pool + inputs: + parameters: + - name: python_script + - name: config + container: + image: us.gcr.io/vcm-ml/beaker-py + resources: + limits: + cpu: "8000m" + memory: "27Gi" + requests: + cpu: "7500m" + memory: "27Gi" + command: ["bash", "-c", "-e"] + args: + - | + cat << EOF > script.py + {{inputs.parameters.python_script}} + EOF + + cat << EOF > config.yaml + {{inputs.parameters.config}} + EOF + + python script.py config.yaml + env: + - name: GOOGLE_APPLICATION_CREDENTIALS + value: /secret/gcp-credentials/key.json + - name: CLOUDSDK_AUTH_CREDENTIAL_FILE_OVERRIDE + value: /secret/gcp-credentials/key.json + - name: BEAKER_TOKEN + valueFrom: + secretKeyRef: + name: beaker-key-andrep + key: BEAKER_USER_KEY + volumeMounts: + - mountPath: /secret/gcp-credentials + name: gcp-key-secret diff --git a/scripts/data_process/compute_dataset_e3smv2.py b/scripts/data_process/compute_dataset_e3smv2.py index 675db96..91e3fa4 100755 --- a/scripts/data_process/compute_dataset_e3smv2.py +++ b/scripts/data_process/compute_dataset_e3smv2.py @@ -18,8 +18,10 @@ import click import numpy as np import xarray as xr -import xpartition # noqa: 401 -from compute_dataset_fv3gfs import ( +import xpartition # noqa: F401 +from compute_dataset import ( + DatasetComputationConfig, + DatasetConfig, assert_column_integral_of_moisture_is_conserved, assert_global_dry_air_mass_conservation, assert_global_moisture_conservation, @@ -34,142 +36,17 @@ from xtorch_harmonics import roundtrip_filter # default paths for input/output; can be changed when calling this script -INPUT_DIR = "/global/cfs/cdirs/e3sm/golaz/E3SM/fme/20230614.v2.LR.F2010/post/atm/180x360_gaussian/ts" # noqa: 501 -TIME_INVARIANT_INPUT_DIR = "/global/cfs/cdirs/m4331/jpduncan/e3smv2/time_invariant" -OUTPUT_URL = "/pscratch/sd/j/jpduncan/ai2/zarr/2023-11-22-e3smv2-vertically-resolved-1deg-fme-dataset.zarr" # noqa: 501 - -# these are subdirs of INPUT_DIR -INSTANT = "6hourly_instant/1yr" -MEAN = "6hourly/1yr" +INPUT_DIR = 
"/global/cfs/cdirs/e3sm/golaz/E3SM/fme/20230614.v2.LR.F2010/post/atm/180x360_gaussian/ts" # noqa: E501 +OUTPUT_URL = "/pscratch/sd/j/jpduncan/ai2/zarr/2023-11-22-e3smv2-vertically-resolved-1deg-fme-dataset.zarr" # noqa: E501 REFERENCE_PRESSURE = 1e5 # Pa LIQUID_PRECIP_DENSITY = 1e3 # kg/m**3 LATENT_HEAT_OF_VAPORIZATION = 2.501e6 # J/kg -# 6-hourly instant -SURFACE_PRESSURE = "PS" -SURFACE_TEMPERATURE = "TS" -AIR_TEMPERATURE = "T" -EASTWARD_WIND = "U" -NORTHWARD_WIND = "V" -MEAN_SEA_LEVEL_PRESSURE = "PSL" -GEOPOTENTIAL = "Z" -RELATIVE_HUMIDITY = "RH" -TOTAL_COLUMN_WATER_VAPOR = "TMQ" - -# 6-hourly mean -TOTAL_PRECIP_RATE = "PRECT" # m/s -LATENT_HEAT_FLUX = "LHFLX" - # derived variable names -SPECIFIC_TOTAL_WATER = "specific_total_water" -PRECIPITABLE_WATER_PATH = "precipitable_water_path" # total from E3SMv2 model outputs -TOTAL_WATER_PATH = "total_water_path" # computed by vertical integration of 3D vars -PRECIP_RATE = "surface_precipitation_rate" SURFACE_UP_LONGWAVE_FLUX = "surface_upward_longwave_flux" SURFACE_UP_SHORTWAVE_FLUX = "surface_upward_shortwave_flux" TOA_UP_SHORTWAVE_FLUX = "top_of_atmos_upward_shortwave_flux" -PRESSURE_THICKNESS = "pressure_thickness_of_atmospheric_layer" - -# dims -TIME_DIM = "time" -HORIZONTAL_DIMS = ["lat", "lon"] -LATITUDE_DIM = "lat" -VERTICAL_DIM = "lev" -VERTICAL_INTERFACE_DIM = "ilev" -HYBRID_LEVEL_COEFFS = ["hyai", "hybi"] - -CHUNKS = {"time": 10, "lat": 180, "lon": 360} - -# assumed to be found in INSTANT dir -FOURCASTNET_VANILLA = { - EASTWARD_WIND: ["1000", "850", "500"], - NORTHWARD_WIND: ["1000", "850", "500"], - AIR_TEMPERATURE: ["850", "500"], - GEOPOTENTIAL: ["1000", "500", "850", "050"], - RELATIVE_HUMIDITY: ["850", "500"], - SURFACE_PRESSURE: [""], - "TREFHT": [""], # temp at 2m - MEAN_SEA_LEVEL_PRESSURE: [""], - TOTAL_COLUMN_WATER_VAPOR: [""], -} - -# the variables / filename prefixes we need from the raw E3SMv2 output -INPUT_VARIABLE_NAMES = { - INSTANT: [ - SURFACE_PRESSURE, - SURFACE_TEMPERATURE, - AIR_TEMPERATURE, - EASTWARD_WIND, - NORTHWARD_WIND, - "Q", - "CLDLIQ", - "CLDICE", - "RAINQM", - "SNOWQM", - TOTAL_COLUMN_WATER_VAPOR, - "TGCLDLWP", - "TGCLDIWP", - "OCNFRAC", - "LANDFRAC", - "ICEFRAC", - ], - MEAN: [ - TOTAL_PRECIP_RATE, - LATENT_HEAT_FLUX, - "SHFLX", - "FLNS", - "FLDS", - "FSNS", - "FSDS", - "FSNTOA", - "SOLIN", - "FLUT", - # only for water budget dataset: - "PRECSC", - "PRECSL", - "QFLX", - ], - "time_invariant": ["PHIS"], -} - - -WATER_SPECIES_NAMES = [ - "Q", - "CLDLIQ", - "CLDICE", - "RAINQM", - "SNOWQM", -] - -VARNAMES_3D = [ - AIR_TEMPERATURE, - EASTWARD_WIND, - NORTHWARD_WIND, -] + WATER_SPECIES_NAMES - -PRECIPITABLE_WATER_PATH_NAMES = [TOTAL_COLUMN_WATER_VAPOR, "TGCLDLWP", "TGCLDIWP"] - -VERTICALLY_RESOLVED_NAMES = [ - SPECIFIC_TOTAL_WATER, - AIR_TEMPERATURE, - NORTHWARD_WIND, - EASTWARD_WIND, -] - -TIME_DERIVATIVE_NAMES = [TOTAL_WATER_PATH] - -# computed here: https://github.com/ai2cm/explore/blob/master/jamesd/2023-06-09-e3smv2-vertical-interface-indices.ipynb # noqa: 501 -VERTICAL_LEVEL_INTERFACES = [ - (0, 19), - (19, 30), - (30, 38), - (38, 44), - (44, 48), - (48, 53), - (53, 61), - (61, 72), -] RAD_FLUX_FORMULAS = { SURFACE_UP_LONGWAVE_FLUX: (lambda x, y: x + y, "FLNS", "FLDS"), @@ -177,6 +54,9 @@ TOA_UP_SHORTWAVE_FLUX: (lambda x, y: x - y, "SOLIN", "FSNTOA"), } +SURFACE_PRECIPITATION = "surface_precipitation_rate" +PRECIPITABLE_WATER_PATH_NAMES = ["TMQ", "TGCLDLWP", "TGCLDIWP"] + DROP_VARIABLE_NAMES = { "2D": [ # variables to drop when opening 2D vars "time_bnds", @@ -197,46 +77,8 @@ "hyam", "hybm", ], - 
"POST": [ # variables to drop at the end - AIR_TEMPERATURE, - EASTWARD_WIND, - NORTHWARD_WIND, - SPECIFIC_TOTAL_WATER, - PRESSURE_THICKNESS, - TOTAL_PRECIP_RATE, - PRECIPITABLE_WATER_PATH, - TOTAL_COLUMN_WATER_VAPOR, - VERTICAL_DIM, - VERTICAL_INTERFACE_DIM, - HYBRID_LEVEL_COEFFS[0], - HYBRID_LEVEL_COEFFS[1], - "Q", - "CLDLIQ", - "CLDICE", - "RAINQM", - "SNOWQM", - "TGCLDLWP", - "TGCLDIWP", - "FLNS", - "FSNS", - "FSNTOA", - "PRECSC", - "PRECSL", - "QFLX", - ], } -# dataset of 2D vars for checking water conservation -WATER_BUDGET_DATASET_VARS = PRECIPITABLE_WATER_PATH_NAMES + [ - SURFACE_PRESSURE, - PRECIPITABLE_WATER_PATH, - TOTAL_PRECIP_RATE, - LATENT_HEAT_FLUX, - "PRECSC", - "PRECSL", - "QFLX", -] - def expand_names_by_level(variables: MutableMapping[str, List[str]]) -> List[str]: names = [] @@ -246,45 +88,50 @@ def expand_names_by_level(variables: MutableMapping[str, List[str]]) -> List[str def get_nc_paths( - base_dir: str, var_names: Optional[List[str]] = None + base_dir: str, var_names: Sequence[str] ) -> MutableMapping[str, List[str]]: - if var_names is None: - paths = {"time_invariant": list(glob(os.path.join(base_dir, f"*.nc")))} - else: - paths = { - var_name: sorted(list(glob(os.path.join(base_dir, f"{var_name}_*.nc")))) - for var_name in var_names - } + paths = { + var_name: sorted(list(glob(os.path.join(base_dir, f"{var_name}_*.nc")))) + for var_name in var_names + } + return paths + + +def get_time_invariant_nc_paths( + base_dir: Optional[str], +) -> MutableMapping[str, List[str]]: + paths = { + "time_invariant": list(glob(os.path.join(base_dir, f"*.nc"))) # type: ignore + } return paths def open_dataset( dataset_dirs: MutableMapping[str, str], - time_invariant_dir: str, - input_variable_names: MutableMapping[str, List[str]] = INPUT_VARIABLE_NAMES, - varnames_3d: List[str] = VARNAMES_3D, - drop_variable_names: MutableMapping[str, List[str]] = DROP_VARIABLE_NAMES, - chunks: MutableMapping[str, int] = CHUNKS, - vanilla: bool = False, + config: DatasetComputationConfig, ) -> xr.Dataset: """Open datasets from NetCDF files in directory that match input variable names.""" - if vanilla: - var_names = expand_names_by_level(FOURCASTNET_VANILLA) - var_paths = get_nc_paths(dataset_dirs[INSTANT], var_names) - else: - var_paths = {} - for key in dataset_dirs.keys(): - var_paths.update(get_nc_paths(dataset_dirs[key], input_variable_names[key])) - var_paths.update(get_nc_paths(time_invariant_dir)) + var_paths: MutableMapping[str, List[str]] = {} + input_variable_names = config.variable_sources + for key in dataset_dirs.keys(): + var_paths.update(get_nc_paths(dataset_dirs[key], input_variable_names[key])) + var_paths.update(get_time_invariant_nc_paths(config.time_invariant_dir)) print( f"Opening {len(list(chain.from_iterable(var_paths.values())))} files with " f"{len(var_paths.keys())} vars..." 
) + standard_names = config.standard_names + varnames_3D = [ + standard_names.air_temperature, + standard_names.eastward_wind, + standard_names.northward_wind, + ] + standard_names.water_species + chunks = config.chunking.get_chunks(config.standard_names) datasets = {} start = time.time() if "time_invariant" in var_paths: for path in var_paths["time_invariant"]: - ds = xr.open_dataset(path).drop(drop_variable_names["2D"], errors="ignore") + ds = xr.open_dataset(path).drop(DROP_VARIABLE_NAMES["2D"], errors="ignore") if "time" in ds.coords: ds = ds.isel(time=0, drop=True) for varname in input_variable_names["time_invariant"]: @@ -293,10 +140,10 @@ def open_dataset( del var_paths["time_invariant"] for varname, paths in var_paths.items(): var_start = time.time() - if varname in varnames_3d: - drop_vars = drop_variable_names["3D"] + if varname in varnames_3D: + drop_vars = DROP_VARIABLE_NAMES["3D"] else: - drop_vars = drop_variable_names["2D"] + drop_vars = DROP_VARIABLE_NAMES["2D"] datasets[varname] = xr.open_mfdataset( paths, chunks=chunks, @@ -311,12 +158,12 @@ def open_dataset( def compute_pressure_thickness( ds: xr.Dataset, - vertical_dim: str = VERTICAL_DIM, - vertical_interface_dim: str = VERTICAL_INTERFACE_DIM, - hybrid_level_coeffs: List[str] = HYBRID_LEVEL_COEFFS, - reference_pressure: float = REFERENCE_PRESSURE, - surface_pressure: str = SURFACE_PRESSURE, - output_name: str = PRESSURE_THICKNESS, + vertical_dim: str, + vertical_interface_dim: str, + hybrid_level_coeffs: List[str], + reference_pressure: float, + surface_pressure: str, + output_name: str, ): hyai, hybi = hybrid_level_coeffs sfc_pressure = ds[surface_pressure].expand_dims( @@ -336,9 +183,9 @@ def compute_pressure_thickness( def compute_coarse_ak_bk( ds: xr.Dataset, - interface_indices: Sequence[Tuple[int, int]] = VERTICAL_LEVEL_INTERFACES, - z_dim="ilev", - hybrid_level_coeffs: List[str] = HYBRID_LEVEL_COEFFS, + interface_indices: Sequence[Tuple[int, int]], + z_dim: str, + hybrid_level_coeffs: List[str], reference_pressure: float = REFERENCE_PRESSURE, ): """Return dataset with scalar ak and bk coordinates that define coarse @@ -376,9 +223,9 @@ def compute_coarse_ak_bk( def compute_surface_precipitation_rate( ds: xr.Dataset, - total_precip_rate_name: str = TOTAL_PRECIP_RATE, + total_precip_rate_name, liquid_precip_density: float = LIQUID_PRECIP_DENSITY, - output_name: str = PRECIP_RATE, + output_name: str = SURFACE_PRECIPITATION, ): precip_mass_flux = ds[total_precip_rate_name] * liquid_precip_density precip_mass_flux.attrs["units"] = "kg/m2/s" @@ -388,8 +235,8 @@ def compute_surface_precipitation_rate( def compute_precipitable_water_path( ds: xr.Dataset, - precipitable_water_path_names: List[str] = PRECIPITABLE_WATER_PATH_NAMES, - output_name: str = PRECIPITABLE_WATER_PATH, + output_name: str, + precipitable_water_path_names: List[str], ): water_path = sum([ds[name] for name in precipitable_water_path_names]) water_path.attrs["units"] = "kg/m2" @@ -412,55 +259,83 @@ def compute_rad_fluxes( def construct_lazy_dataset( + config: DatasetComputationConfig, dataset_dirs: MutableMapping[str, str], - time_invariant_dir: str, - vanilla: bool = False, - sht_roundtrip: bool = False, ) -> xr.Dataset: start = time.time() + standard_names = config.standard_names print(f"Opening dataset...") - ds = open_dataset(dataset_dirs, time_invariant_dir, vanilla=vanilla) + ds = open_dataset(dataset_dirs, config) print(f"Dataset opened in {time.time() - start:.2f} s total.") print(f"Input dataset size is {ds.nbytes / 1e9} GB") - if 
sht_roundtrip: - ds = roundtrip_filter(ds, lat_dim="lat", lon_dim="lon") - if not vanilla: - ds = compute_pressure_thickness(ds) - ds = compute_rad_fluxes(ds) - ds = compute_surface_precipitation_rate(ds) - ds = compute_precipitable_water_path(ds) # only used for conservation check - ds = compute_specific_total_water(ds, WATER_SPECIES_NAMES, SPECIFIC_TOTAL_WATER) - ds = compute_vertical_coarsening( - ds, - VERTICALLY_RESOLVED_NAMES, - VERTICAL_LEVEL_INTERFACES, - VERTICAL_DIM, - PRESSURE_THICKNESS, - ) - ds = compute_column_moisture_integral( + if config.roundtrip_fraction_kept is not None: + ds = roundtrip_filter( ds, - SPECIFIC_TOTAL_WATER, - TOTAL_WATER_PATH, - PRESSURE_THICKNESS, - VERTICAL_DIM, + lat_dim=standard_names.latitude_dim, + lon_dim=standard_names.longitude_dim, + fraction_modes_kept=config.roundtrip_fraction_kept, ) - ds[TOTAL_WATER_PATH].attrs["units"] = "kg/m2" # change to E3SMv2 format - ds = compute_tendencies(ds, TIME_DERIVATIVE_NAMES, TIME_DIM) - ds = compute_column_advective_moisture_tendency( - ds, - f"tendency_of_{TOTAL_WATER_PATH}", - LATENT_HEAT_FLUX, - PRECIP_RATE, - LATENT_HEAT_OF_VAPORIZATION, - ) - ak_bk_ds = compute_coarse_ak_bk(ds, VERTICAL_LEVEL_INTERFACES) - ds = xr.merge([ds, ak_bk_ds]) - ds_dirs = list(dataset_dirs.values()) - else: - ds_dirs = [dataset_dirs[INSTANT]] - ds = ds.chunk(CHUNKS).astype(np.float32) + + ds = compute_pressure_thickness( + ds, + vertical_dim=standard_names.vertical_dim, + vertical_interface_dim=standard_names.vertical_interface_dim, + hybrid_level_coeffs=standard_names.hybrid_level_coeffs, + reference_pressure=REFERENCE_PRESSURE, + surface_pressure=standard_names.surface_pressure, + output_name=standard_names.pressure_thickness, + ) + ds = compute_rad_fluxes(ds) + ds = compute_surface_precipitation_rate( + ds, + total_precip_rate_name=standard_names.precip_rate, + ) + water_species_name = [ + item for item in standard_names.water_species if item.lower() != "none" + ] + ds = compute_specific_total_water( + ds, + water_condensate_names=water_species_name, + output_name=standard_names.specific_total_water, + ) + ds = compute_vertical_coarsening( + ds, + vertically_resolved_names=standard_names.vertically_resolved, + interface_indices=config.vertical_coarsening_indices, + dim=standard_names.vertical_dim, + pressure_thickness_name=standard_names.pressure_thickness, + ) + ds = compute_column_moisture_integral( + ds, + input_name=standard_names.specific_total_water, + output_name=standard_names.total_water_path, + pressure_thickness_name=standard_names.pressure_thickness, + dim=standard_names.vertical_dim, + ) + ds = compute_tendencies( + ds, + time_derivative_names=standard_names.time_derivative_names, + dim=standard_names.time_dim, + ) + ds = compute_column_advective_moisture_tendency( + ds, + pwat_tendency=standard_names.pwat_tendency, + latent_heat_flux=standard_names.latent_heat_flux, + precip=standard_names.precip_rate, + latent_heat_of_vaporization=LATENT_HEAT_OF_VAPORIZATION, + ) + ak_bk_ds = compute_coarse_ak_bk( + ds, + interface_indices=config.vertical_coarsening_indices, + z_dim=standard_names.vertical_interface_dim, + hybrid_level_coeffs=standard_names.hybrid_level_coeffs, + ) + ds = xr.merge([ds, ak_bk_ds]) + ds_dirs = list(dataset_dirs.values()) + chunks = config.chunking.get_chunks(config.standard_names) + ds = ds.chunk(chunks).astype(np.float32) ds.attrs["history"] = ( - "Dataset computed by ace/scripts/e3smv2_data_process" + "Dataset computed by full-model/scripts/e3smv2_data_process" "/compute_dataset_e3smv2.py" f" 
script, using inputs from the following directories: {ds_dirs}." ) @@ -474,94 +349,121 @@ def construct_lazy_dataset( @click.command() -@click.option("--debug", is_flag=True, help="Print metadata instead of writing output.") -@click.option("--subsample", is_flag=True, help="Subsample the data before writing.") -@click.option("--vanilla", is_flag=True, help="Compute vanilla FourCastNet dataset.") -@click.option("--check-conservation", is_flag=True, help="Check conservation.") -@click.option( - "--water-budget-dataset", - is_flag=True, - help="Create a dataset of 2D vars for checking the water budget.", -) -@click.option("--sht-roundtrip", is_flag=True, help="SHT roundtrip as a first step.") +@click.option("--config", help="Path to dataset configuration YAML file.") @click.option( "-i", "--input-dir", default=INPUT_DIR, help="Directory in which to find time-varying input ncs.", ) +@click.option("-o", "--output", default=OUTPUT_URL, help="URL to write output to.") +@click.option("--debug", is_flag=True, help="Print metadata instead of writing output.") +@click.option("--subsample", is_flag=True, help="Subsample the data before writing.") +@click.option("--check-conservation", is_flag=True, help="Check conservation.") @click.option( - "-t", - "--time-invariant-input-dir", - default=TIME_INVARIANT_INPUT_DIR, - help="Directory in which to find time-invariant input ncs.", + "--water-budget-dataset", + is_flag=True, + help="Create a dataset of 2D vars for checking the water budget.", ) -@click.option("-o", "--output", default=OUTPUT_URL, help="URL to write output to.") -@click.option("--n-split", default=100, help="Number of steps to split job over.") @click.option("--n-workers", default=4, help="Number of Dask workers.") def main( + config, + input_dir, + output, debug, subsample, - vanilla, - sht_roundtrip, check_conservation, water_budget_dataset, - input_dir, - time_invariant_input_dir, - output, - n_split, n_workers, ): xr.set_options(keep_attrs=True) _ = Client(n_workers=n_workers) - - dataset_dirs = { - INSTANT: os.path.join(input_dir, INSTANT), - MEAN: os.path.join(input_dir, MEAN), - } - ds = construct_lazy_dataset( - dataset_dirs, time_invariant_input_dir, vanilla, sht_roundtrip - ) + config = DatasetConfig.from_file(config).dataset_computation + standard_names = config.standard_names + dataset_dirs = {} + for key in config.variable_sources.keys(): + if key != "time_invariant": + dataset_dirs[key] = os.path.join(input_dir, key) + ds = construct_lazy_dataset(config, dataset_dirs) if subsample: ds = ds.isel(time=slice(10, 13)) if check_conservation: # these currently fail + ds = compute_precipitable_water_path( + ds, + output_name=standard_names.precipitable_water_path, + precipitable_water_path_names=PRECIPITABLE_WATER_PATH_NAMES, + ) assert_column_integral_of_moisture_is_conserved( - ds, PRECIPITABLE_WATER_PATH, TOTAL_WATER_PATH + ds, standard_names.precipitable_water_path, standard_names.total_water_path ) assert_global_dry_air_mass_conservation( ds, - dims=HORIZONTAL_DIMS, - surface_pressure_name=SURFACE_PRESSURE, - total_water_path_name=TOTAL_WATER_PATH, - latitude_dim=LATITUDE_DIM, - time_dim=TIME_DIM, + dims=standard_names.horizontal_dims, + surface_pressure_name=standard_names.surface_pressure, + total_water_path_name=standard_names.total_water_path, + latitude_dim=standard_names.latitude_dim, + time_dim=standard_names.time_dim, ) assert_global_moisture_conservation( ds, - dims=HORIZONTAL_DIMS, - latitude_dim=LATITUDE_DIM, - total_water_path_name=TOTAL_WATER_PATH, - 
latent_heat_flux_name=LATENT_HEAT_FLUX, + dims=standard_names.horizontal_dims, + latitude_dim=standard_names.latitude_dim, + total_water_path_name=standard_names.total_water_path, + latent_heat_flux_name=standard_names.latent_heat_flux, latent_heat_of_vaporization=LATENT_HEAT_OF_VAPORIZATION, - precip_rate_name=PRECIP_RATE, - time_dim=TIME_DIM, + precip_rate_name=standard_names.precip_rate, + time_dim=standard_names.time_dim, ) if water_budget_dataset: - ds = ds[WATER_BUDGET_DATASET_VARS] + water_budget_dataset_vars = [ + standard_names.surface_pressure, + standard_names.precipitable_water_path, + standard_names.precip_rate, + standard_names.latent_heat_flux, + "PRECSC", + "PRECSL", + "QFLX", + "TMQ", + "TGCLDLWP", + "TGCLDIWP", + ] + + ds = ds[water_budget_dataset_vars] else: - ds = ds.drop(DROP_VARIABLE_NAMES["POST"], errors="ignore") + dropped_variables = ( + [ + item + for item in standard_names.dropped_variables + if item.lower() != "none" + ] + + standard_names.hybrid_level_coeffs + + [ + standard_names.precip_rate, + standard_names.vertical_interface_dim, + "TMQ", + "TGCLDLWP", + "TGCLDIWP", + "FLNS", + "FSNS", + "FSNTOA", + "PRECSC", + "PRECSL", + "QFLX", + ] + ) + ds = ds.drop(dropped_variables, errors="ignore") print(f"Output dataset size is {ds.nbytes / 1e9} GB") if debug: with xr.set_options(display_max_rows=500): print(ds) else: ds.partition.initialize_store(output) - for i in range(n_split): - print(f"Writing segment {i + 1} / {n_split}") + for i in range(config.n_split): + print(f"Writing segment {i + 1} / {config.n_split}") with ProgressBar(): ds.partition.write( - output, n_split, ["time"], i, collect_variable_writes=True + output, config.n_split, ["time"], i, collect_variable_writes=True ) diff --git a/scripts/data_process/compute_hpx_dataset.py b/scripts/data_process/compute_hpx_dataset.py new file mode 100644 index 0000000..5556f5f --- /dev/null +++ b/scripts/data_process/compute_hpx_dataset.py @@ -0,0 +1,228 @@ +# This script is used to compute a training dataset from the "raw" +# FV3GFS data stored in zarr form on GCS. + +# The dependencies of this script are installed in the "fv3net" conda environment +# which can be installed using fv3net's Makefile. See +# https://github.com/ai2cm/fv3net/blob/8ed295cf0b8ca49e24ae5d6dd00f57e8b30169ac/Makefile#L310 + +# The resulting dataset is about 194GB (the input is about 2.5TB). Running this script +# on my 8-CPU VM takes about 2.5 hours. See "compute_dataset_fv3gfs_argo_workflow.yaml" +# for a workflow which parallelizes this script across the 11-member ensemble and runs +# it on our GKE cluster. 
+ +import os +import pdb +import sys +from typing import Tuple + +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +import multiprocessing as mp +from functools import partial + +import click +import earth2grid +import numpy as np +import torch +import xarray as xr +import xpartition # noqa +from compute_dataset import ( + LATENT_HEAT_OF_VAPORIZATION, + DatasetComputationConfig, + DatasetConfig, + DLWPChunkingConfig, + DLWPNameMapping, + assert_column_integral_of_moisture_is_conserved, + assert_global_dry_air_mass_conservation, + assert_global_moisture_conservation, + get_dataset_urls, + open_datasets, +) + + +def regrid_tensor(x, regrid_func, shape): + data = regrid_func(torch.tensor(x, dtype=torch.double)) + return data.numpy().reshape(shape) + + +def _pool_func(ds, store, n_partitions, partition_dims, i): + ds.partition.write(store, n_partitions, partition_dims, i) + print(f"Finished writing partition {i}") + + +def hpx_regrid( + ds: xr.Dataset, + dlwp_names: DLWPNameMapping, + level: int = 6, # regrid resolution + n_side: int = 64, +) -> Tuple[xr.Dataset, xr.Dataset]: + lat_long_names = dlwp_names.lat_lon_dims + longitude = lat_long_names[1] + latitude = lat_long_names[0] + lons = ds[longitude] + lats = ds[latitude] + + hpx = earth2grid.healpix.Grid( + level=level, pixel_order=earth2grid.healpix.HEALPIX_PAD_XY + ) + src = earth2grid.latlon.LatLonGrid(lat=list(lats), lon=list(lons)) + # Regridder + regrid = earth2grid.get_regridder(src, hpx) + + ds_regridded = xr.apply_ufunc( + regrid_tensor, + ds, + input_core_dims=[[latitude, longitude]], + output_core_dims=[["face", "height", "width"]], + output_sizes={"face": 12, "height": n_side, "width": n_side}, + output_dtypes=[float], + dask="parallelized", + vectorize=True, + on_missing_core_dim="copy", + kwargs={"regrid_func": regrid, "shape": (12, n_side, n_side)}, + dask_gufunc_kwargs={"allow_rechunk": True}, + ) + # Assign coordinates to the regridded dataset + time_coords = ds.coords["time"] + nside_coords = np.arange(n_side) + grid_coords = np.arange(12) + ds_regridded = ds_regridded.assign_coords( + time=time_coords, + face=grid_coords, + height=nside_coords, + width=nside_coords, + ) + + return ds_regridded + + +def construct_hpx_dataset( + config: DatasetComputationConfig, + run_directory: str, + output_directory: str, + toy_dataset: bool = False, +) -> xr.Dataset: + dlwp_names = config.standard_names + if not isinstance(dlwp_names, DLWPNameMapping): + raise TypeError("Expected to be passed type of DLWPNameMapping.") + dlwp_chunking = config.chunking + if not isinstance(dlwp_chunking, DLWPChunkingConfig): + raise TypeError("Expected to be passed type of DLWPChunkingConfig.") + + urls = get_dataset_urls(config, run_directory) + print(urls) + ds = open_datasets(config, urls) + for var in ds: + del ds[var].encoding["chunks"] + del ds[var].encoding["preferred_chunks"] + print(f"Input dataset size is {ds.nbytes / 1e9} GB") + if toy_dataset: + ds = ds.isel(time=slice(0, 200)) + # We would like to: + + # 1. map to healpix mesh + ds = hpx_regrid( + ds=ds, + dlwp_names=dlwp_names, + n_side=64, + ) + print(f"After regrid: {ds}") + + # 2. chunk and save + chunks = config.chunking.get_chunks(dlwp_names) + ds = ds.chunk(chunks) + ds.attrs["history"] = ( + "Dataset computed by full-model/scripts/data_process" + "/compute_hpx_dataset.py" + f" script, using following input zarrs: {urls.values()}." 
+ ) + ds.attrs["vertical_coordinate"] = ( + "The pressure at level interfaces can be computed as " + "p_i = ak_i + bk_i * PRESsfc, where PRESsfc is the surface pressure and the " + "p_i pressure corresponds to the interface at the top of the i'th finite " + "volume layer, counting down from the top of atmosphere." + ) + ds = ds.rename(config.renaming) + return ds + + +@click.command() +@click.option("--config", help="Path to dataset configuration YAML file.") +@click.option("--run-directory", help="Path to reference run directory.") +@click.option("--output-store", help="Path to output zarr store.") +@click.option("--debug", is_flag=True, help="Print metadata instead of writing output.") +@click.option("--subsample", is_flag=True, help="Subsample the data before writing.") +@click.option("--check-conservation", is_flag=True, help="Check conservation.") +@click.option("--num-processes", default=16, help="Number of processes to spin up.") +def main( + config, + run_directory, + output_store, + debug, + subsample, + check_conservation, + num_processes, +): + config = DatasetConfig.from_file(config).dataset_computation + dlwp_names = config.standard_names + print(f"--run-directory is {run_directory}") + print(f"--output-store is {output_store}") + ds = construct_hpx_dataset( + config=config, + run_directory=run_directory, + output_directory=output_store, + toy_dataset=False, + ) + if subsample: + ds = ds.isel(time=slice(10, 13)) + if check_conservation: + assert_column_integral_of_moisture_is_conserved( + ds, + precipitable_water_path_name=dlwp_names.precipitable_water_path, + total_water_path_name=dlwp_names.total_water_path, + ) + assert_global_dry_air_mass_conservation( + ds, + dims=dlwp_names.horizontal_dims, + surface_pressure_name=dlwp_names.surface_pressure, + total_water_path_name=dlwp_names.total_water_path, + latitude_dim=dlwp_names.latitude_dim, + time_dim=dlwp_names.time_dim, + ) + assert_global_moisture_conservation( + ds, + dims=dlwp_names.horizontal_dims, + latitude_dim=dlwp_names.latitude_dim, + total_water_path_name=dlwp_names.total_water_path, + latent_heat_flux_name=dlwp_names.latent_heat_flux, + latent_heat_of_vaporization=LATENT_HEAT_OF_VAPORIZATION, + precip_rate_name=dlwp_names.precip_rate, + time_dim=dlwp_names.time_dim, + ) + drop_vars = [var for var in dlwp_names.dropped_variables if var in ds] + ds = ds.drop(drop_vars) + print(f"Output dataset size is {ds.nbytes / 1e9} GB") + + if debug: + with xr.set_options(display_max_rows=500): + print(ds) + else: + n_partitions = config.n_split + partition_dims = [dlwp_names.time_dim] + store = f"{output_store}.zarr" + ds.partition.initialize_store(store) + + with mp.get_context("forkserver").Pool(num_processes) as pool: + pool.map( + partial(_pool_func, ds, store, n_partitions, partition_dims), + range(n_partitions), + ) + + +if __name__ == "__main__": + try: + main() + except Exception as e: + print(f"An error occurred: {e}", file=sys.stderr) + pdb.post_mortem() # Start the debugger + raise # Re-raise the exception to preserve the traceback diff --git a/scripts/data_process/compute_hpx_dataset.sh b/scripts/data_process/compute_hpx_dataset.sh new file mode 100755 index 0000000..92f874a --- /dev/null +++ b/scripts/data_process/compute_hpx_dataset.sh @@ -0,0 +1,52 @@ +#!/bin/bash + +# Function to check if directory is in Python path and add it if not +function add_to_pythonpath() { + local dir_to_add="$1" + if [[ ":$PYTHONPATH:" != *":$dir_to_add:"* ]]; then + export PYTHONPATH="$dir_to_add:$PYTHONPATH" + fi +} + +# Add your own
full-model directory to Python path if not already included +add_to_pythonpath "~/full-model" + +ARGO=false + +while [[ "$#" -gt 0 ]] +do + case $1 in + --config) CONFIG="$2" + shift;; + --argo) ARGO="$2" + shift;; + *) echo "Unknown parameter passed: $1" + exit 1;; + esac + shift +done + +if [[ -z "${CONFIG}" ]] +then + echo "Option --config missing" + exit 1 +fi + +run_directory=$(yq -r '.runs.run_directory' ${CONFIG}) +output_directory=$(yq -r '.data_output_directory' ${CONFIG}) + +if [[ "$ARGO" == "true" ]] +then + output=$(argo submit full-model/scripts/data_process/compute_hpx_dataset_argo_workflow.yaml \ + -p python_script="$(< full-model/scripts/data_process/compute_hpx_dataset.py)" \ + -p config="$(< ${CONFIG})" \ + -p run_directory="${run_directory}" \ + -p output_directory="${output_directory}") + + job_name=$(echo "$output" | grep 'Name:' | awk '{print $2}') + echo "Argo job submitted: $job_name" +else + python3 full-model/scripts/data_process/compute_hpx_dataset.py --config="${CONFIG}" \ + --run-directory="${run_directory}" \ + --output-store="${output_directory}" +fi \ No newline at end of file diff --git a/scripts/data_process/compute_repeating_forcing.py b/scripts/data_process/compute_repeating_forcing.py index fdef451..4ed7659 100755 --- a/scripts/data_process/compute_repeating_forcing.py +++ b/scripts/data_process/compute_repeating_forcing.py @@ -97,7 +97,7 @@ def main(n_times: int, input_dir: Path, output_dir: Path, repeat_variables, nc_f time_coord = xr.cftime_range( ds.time.item(0), periods=len(ds.time) * n_times, - freq=f"{dt}H", + freq=f"{dt}h", calendar=ds.time.dt.calendar, ) diff --git a/scripts/data_process/compute_stats.sh b/scripts/data_process/compute_stats.sh new file mode 100755 index 0000000..ceb2c2f --- /dev/null +++ b/scripts/data_process/compute_stats.sh @@ -0,0 +1,3 @@ +#!/bin/bash + +./compute_dataset.sh --stats-only "$@" diff --git a/scripts/data_process/configs/e3sm-1deg-8layer.yaml b/scripts/data_process/configs/e3sm-1deg-8layer.yaml new file mode 100644 index 0000000..1432f89 --- /dev/null +++ b/scripts/data_process/configs/e3sm-1deg-8layer.yaml @@ -0,0 +1,85 @@ +runs: + 2024-07-10-e3smv2-1deg-testing: "" +data_output_directory: /global/cfs/cdirs/m4492/fme-preprocess/zarr/ +stats: + output_directory: /global/cfs/cdirs/m4492/fme-preprocess/2024-07-10-e3smv2-1deg-testing + start_date: "1970-01-01" + end_date: "1970-12-31" + data_type: E3SMV2 + beaker_dataset: e3sm-1deg-8layers-stats-1970 # this is not used in e3sm data processing +dataset_computation: + chunking: + time_dim: 10 + latitude_dim: 180 + longitude_dim: 360 + reference_vertical_coordinate_file: None + time_invariant_dir: /global/cfs/cdirs/m4331/jpduncan/e3smv2/time_invariant + vertical_coarsening_indices: +# computed here: https://github.com/ai2cm/explore/blob/master/jamesd/2023-06-09-e3smv2-vertical-interface-indices.ipynb + - [0, 19] + - [19, 30] + - [30, 38] + - [38, 44] + - [44, 48] + - [48, 53] + - [53, 61] + - [61, 72] + roundtrip_fraction_kept: 1.0 + n_split: 100 + variable_sources: + time_invariant: + - PHIS + 6hourly_instant/1yr: + - PS + - TS + - T + - U + - V + - Q + - CLDLIQ + - CLDICE + - RAINQM + - SNOWQM + - TMQ + - TGCLDLWP + - TGCLDIWP + - OCNFRAC + - LANDFRAC + - ICEFRAC + 6hourly/1yr: + - PRECT + - LHFLX + - SHFLX + - FLNS + - FLDS + - FSNS + - FSDS + - FSNTOA + - SOLIN + - FLUT + - PRECSC + - PRECSL + - QFLX + standard_names: + longitude_dim: lon + latitude_dim: lat + vertical_dim: lev + vertical_interface_dim: ilev + time_dim: time + surface_pressure: PS + 
latent_heat_flux: LHFLX + precip_rate: PRECT + precipitable_water_path: precipitable_water_path + pressure_thickness: pressure_thickness_of_atmospheric_layer + air_temperature: T + specific_humidity: Q + cloud_water_mixing_ratio: CLDLIQ + cloud_ice_mixing_ratio: CLDICE + graupel_mixing_ratio: None + rain_mixing_ratio: RAINQM + snow_mixing_ratio: SNOWQM + northward_wind: V + eastward_wind: U + hybrid_level_coeffs: + - hyai + - hybi \ No newline at end of file diff --git a/scripts/data_process/configs/era5-1deg-16layer-1940-2022.yaml b/scripts/data_process/configs/era5-1deg-16layer-1940-2022.yaml new file mode 100644 index 0000000..b01df38 --- /dev/null +++ b/scripts/data_process/configs/era5-1deg-16layer-1940-2022.yaml @@ -0,0 +1,9 @@ +runs: + 2024-07-11-era5-1deg-16layer-1940-2022: "" # no real data source, config only for computing stats +data_output_directory: gs://vcm-ml-intermediate +stats: + output_directory: gs://vcm-ml-intermediate/era5-1deg-16layer-stats-1990-2019 + beaker_dataset: era5-1deg-16layer-stats-1990-2019 + start_date: "1990-01-01" + end_date: "2019-12-31" + data_type: ERA5 diff --git a/scripts/data_process/configs/era5-1deg-8layer-1940-2022.yaml b/scripts/data_process/configs/era5-1deg-8layer-1940-2022.yaml new file mode 100644 index 0000000..b200418 --- /dev/null +++ b/scripts/data_process/configs/era5-1deg-8layer-1940-2022.yaml @@ -0,0 +1,9 @@ +runs: + 2024-06-20-era5-1deg-8layer-1940-2022: "" # no real data source, config only for computing stats +data_output_directory: gs://vcm-ml-intermediate +stats: + output_directory: gs://vcm-ml-intermediate/2024-06-20-era5-1deg-8layer-stats-1990-2019 + beaker_dataset: era5-1deg-8layer-stats-1990-2019-v2 + start_date: "1990-01-01" + end_date: "2019-12-31" + data_type: ERA5 diff --git a/scripts/data_process/configs/fv3gfs-amip-ensemble-1deg-8layer.yaml b/scripts/data_process/configs/fv3gfs-amip-ensemble-1deg-8layer.yaml new file mode 100644 index 0000000..ef7ac9e --- /dev/null +++ b/scripts/data_process/configs/fv3gfs-amip-ensemble-1deg-8layer.yaml @@ -0,0 +1,81 @@ +runs: + ic_0001: gs://vcm-ml-raw-flexible-retention/2023-10-20-C96-FME-AMIP-ensemble-reference/regridded-zarrs/gaussian_grid_180_by_360/ic_0001 + ic_0002: gs://vcm-ml-raw-flexible-retention/2023-10-20-C96-FME-AMIP-ensemble-reference/regridded-zarrs/gaussian_grid_180_by_360/ic_0002 + ic_0003: gs://vcm-ml-raw-flexible-retention/2023-10-20-C96-FME-AMIP-ensemble-reference/regridded-zarrs/gaussian_grid_180_by_360/ic_0003 + ic_0004: gs://vcm-ml-raw-flexible-retention/2023-10-20-C96-FME-AMIP-ensemble-reference/regridded-zarrs/gaussian_grid_180_by_360/ic_0004 +data_output_directory: gs://vcm-ml-intermediate/2023-10-27-vertically-resolved-1deg-fme-amip-ensemble-dataset +stats: + output_directory: gs://vcm-ml-intermediate/2023-10-27-vertically-resolved-1deg-fme-amip-ensemble-dataset-stats + beaker_dataset: 2023-10-27-vertically-resolved-1deg-fme-amip-ensemble-dataset-stats + start_date: "1990-01-01" + end_date: "2019-12-31" + data_type: FV3GFS + exclude_runs: + - "ic_0004" +dataset_computation: + reference_vertical_coordinate_file: gs://vcm-ml-raw-flexible-retention/2023-04-13-11-year-C96-FME-reference/vertical-coordinate-file/fv_core.res.nc + vertical_coarsening_indices: + - [0, 18] + - [18, 26] + - [26, 31] + - [31, 36] + - [36, 41] + - [41, 47] + - [47, 53] + - [53, 63] + renaming: + specific_humidity_at_two_meters: Q2m + variable_sources: + fluxes_2d.zarr: + - PRATEsfc + - LHTFLsfc + - SHTFLsfc + - DLWRFsfc + - DSWRFsfc + - DSWRFtoa + - ULWRFsfc + - ULWRFtoa + - USWRFsfc + - 
USWRFtoa + - precipitable_water_path + - GRAUPELsfc + - ICEsfc + - SNOWsfc + fourcastnet_vanilla.zarr: + - PRESsfc + - HGTsfc + - RH500 + - RH850 + - TMP500 + - TMP850 + - UGRD500 + - UGRD850 + - UGRD1000 + - VGRD500 + - VGRD850 + - VGRD1000 + - h50 + - h500 + - h850 + - h1000 + - TMP2m + - UGRD10m + - VGRD10m + full_state.zarr: + - surface_temperature + - air_temperature + - specific_humidity + - cloud_water_mixing_ratio + - cloud_ice_mixing_ratio + - graupel_mixing_ratio + - rain_mixing_ratio + - snow_mixing_ratio + - northward_wind + - eastward_wind + - pressure_thickness_of_atmospheric_layer + - soil_moisture + - specific_humidity_at_two_meters + encoded_surface_type.zarr: + - land_fraction + - ocean_fraction + - sea_ice_fraction diff --git a/scripts/data_process/configs/fv3gfs-amip-ensemble-4deg-8layer.yaml b/scripts/data_process/configs/fv3gfs-amip-ensemble-4deg-8layer.yaml new file mode 100644 index 0000000..d329622 --- /dev/null +++ b/scripts/data_process/configs/fv3gfs-amip-ensemble-4deg-8layer.yaml @@ -0,0 +1,81 @@ +runs: + ic_0001: gs://vcm-ml-raw-flexible-retention/2023-10-20-C96-FME-AMIP-ensemble-reference/regridded-zarrs/gaussian_grid_45_by_90/ic_0001 + ic_0002: gs://vcm-ml-raw-flexible-retention/2023-10-20-C96-FME-AMIP-ensemble-reference/regridded-zarrs/gaussian_grid_45_by_90/ic_0002 + ic_0003: gs://vcm-ml-raw-flexible-retention/2023-10-20-C96-FME-AMIP-ensemble-reference/regridded-zarrs/gaussian_grid_45_by_90/ic_0003 + ic_0004: gs://vcm-ml-raw-flexible-retention/2023-10-20-C96-FME-AMIP-ensemble-reference/regridded-zarrs/gaussian_grid_45_by_90/ic_0004 +data_output_directory: gs://vcm-ml-intermediate/2023-10-27-vertically-resolved-4deg-fme-amip-ensemble-dataset +stats: + output_directory: gs://vcm-ml-intermediate/2023-10-27-vertically-resolved-4deg-fme-amip-ensemble-dataset-stats + beaker_dataset: 2023-10-27-vertically-resolved-4deg-fme-amip-ensemble-dataset-stats + start_date: "1990-01-01" + end_date: "2019-12-31" + exclude_runs: + - "ic_0004" + data_type: FV3GFS +dataset_computation: + reference_vertical_coordinate_file: gs://vcm-ml-raw-flexible-retention/2023-04-13-11-year-C96-FME-reference/vertical-coordinate-file/fv_core.res.nc + vertical_coarsening_indices: + - [0, 18] + - [18, 26] + - [26, 31] + - [31, 36] + - [36, 41] + - [41, 47] + - [47, 53] + - [53, 63] + renaming: + specific_humidity_at_two_meters: Q2m + variable_sources: + fluxes_2d.zarr: + - PRATEsfc + - LHTFLsfc + - SHTFLsfc + - DLWRFsfc + - DSWRFsfc + - DSWRFtoa + - ULWRFsfc + - ULWRFtoa + - USWRFsfc + - USWRFtoa + - precipitable_water_path + - GRAUPELsfc + - ICEsfc + - SNOWsfc + fourcastnet_vanilla.zarr: + - PRESsfc + - HGTsfc + - RH500 + - RH850 + - TMP500 + - TMP850 + - UGRD500 + - UGRD850 + - UGRD1000 + - VGRD500 + - VGRD850 + - VGRD1000 + - h50 + - h500 + - h850 + - h1000 + - TMP2m + - UGRD10m + - VGRD10m + full_state.zarr: + - surface_temperature + - air_temperature + - specific_humidity + - cloud_water_mixing_ratio + - cloud_ice_mixing_ratio + - graupel_mixing_ratio + - rain_mixing_ratio + - snow_mixing_ratio + - northward_wind + - eastward_wind + - pressure_thickness_of_atmospheric_layer + - soil_moisture + - specific_humidity_at_two_meters + encoded_surface_type.zarr: + - land_fraction + - ocean_fraction + - sea_ice_fraction diff --git a/scripts/data_process/configs/fv3gfs-c48-ensemble-1deg-8layer.yaml b/scripts/data_process/configs/fv3gfs-c48-ensemble-1deg-8layer.yaml new file mode 100644 index 0000000..fdaff68 --- /dev/null +++ b/scripts/data_process/configs/fv3gfs-c48-ensemble-1deg-8layer.yaml @@ 
-0,0 +1,77 @@ +runs: + ic_0011: gs://vcm-ml-raw-flexible-retention/2023-08-03-C48-FME-reference-ensemble/regridded-zarrs/gaussian_grid_180_by_360/ic_0011_2021010100 +data_output_directory: gs://vcm-ml-intermediate/2023-09-01-vertically-resolved-1deg-fme-c48-baseline-dataset +stats: + output_directory: gs://vcm-ml-intermediate/2023-09-01-vertically-resolved-1deg-fme-c48-baseline-dataset-stats + beaker_dataset: 2023-09-01-vertically-resolved-1deg-fme-c48-baseline-dataset-stats + start_date: "2021-01-01" + end_date: "2030-12-31" + data_type: FV3GFS +dataset_computation: + reference_vertical_coordinate_file: gs://vcm-ml-raw-flexible-retention/2023-04-13-11-year-C96-FME-reference/vertical-coordinate-file/fv_core.res.nc + vertical_coarsening_indices: + - [0, 18] + - [18, 26] + - [26, 31] + - [31, 36] + - [36, 41] + - [41, 47] + - [47, 53] + - [53, 63] + renaming: + specific_humidity_at_two_meters: Q2m + variable_sources: + fluxes_2d.zarr: + - PRATEsfc + - LHTFLsfc + - SHTFLsfc + - DLWRFsfc + - DSWRFsfc + - DSWRFtoa + - ULWRFsfc + - ULWRFtoa + - USWRFsfc + - USWRFtoa + - precipitable_water_path + - GRAUPELsfc + - ICEsfc + - SNOWsfc + fourcastnet_vanilla.zarr: + - PRESsfc + - HGTsfc + - RH500 + - RH850 + - TMP500 + - TMP850 + - UGRD500 + - UGRD850 + - UGRD1000 + - VGRD500 + - VGRD850 + - VGRD1000 + - h50 + - h500 + - h850 + - h1000 + - TMP2m + - UGRD10m + - VGRD10m + full_state.zarr: + - surface_temperature + - air_temperature + - specific_humidity + - cloud_water_mixing_ratio + - cloud_ice_mixing_ratio + - graupel_mixing_ratio + - rain_mixing_ratio + - snow_mixing_ratio + - northward_wind + - eastward_wind + - pressure_thickness_of_atmospheric_layer + - soil_moisture + - specific_humidity_at_two_meters + encoded_surface_type.zarr: + - land_fraction + - ocean_fraction + - sea_ice_fraction + roundtrip_fraction_kept: 0.65 diff --git a/scripts/data_process/configs/fv3gfs-ensemble-1deg-8layer.yaml b/scripts/data_process/configs/fv3gfs-ensemble-1deg-8layer.yaml new file mode 100644 index 0000000..08a8379 --- /dev/null +++ b/scripts/data_process/configs/fv3gfs-ensemble-1deg-8layer.yaml @@ -0,0 +1,88 @@ +runs: + ic_0001: gs://vcm-ml-raw-flexible-retention/2023-07-08-C96-FME-reference-ensemble/regridded-zarrs/gaussian_grid_180_by_360/ic_0001 + ic_0002: gs://vcm-ml-raw-flexible-retention/2023-07-08-C96-FME-reference-ensemble/regridded-zarrs/gaussian_grid_180_by_360/ic_0002 + ic_0003: gs://vcm-ml-raw-flexible-retention/2023-07-08-C96-FME-reference-ensemble/regridded-zarrs/gaussian_grid_180_by_360/ic_0003 + ic_0004: gs://vcm-ml-raw-flexible-retention/2023-07-08-C96-FME-reference-ensemble/regridded-zarrs/gaussian_grid_180_by_360/ic_0004 + ic_0005: gs://vcm-ml-raw-flexible-retention/2023-07-08-C96-FME-reference-ensemble/regridded-zarrs/gaussian_grid_180_by_360/ic_0005 + ic_0006: gs://vcm-ml-raw-flexible-retention/2023-07-08-C96-FME-reference-ensemble/regridded-zarrs/gaussian_grid_180_by_360/ic_0006 + ic_0007: gs://vcm-ml-raw-flexible-retention/2023-07-08-C96-FME-reference-ensemble/regridded-zarrs/gaussian_grid_180_by_360/ic_0007 + ic_0008: gs://vcm-ml-raw-flexible-retention/2023-07-08-C96-FME-reference-ensemble/regridded-zarrs/gaussian_grid_180_by_360/ic_0008 + ic_0009: gs://vcm-ml-raw-flexible-retention/2023-07-08-C96-FME-reference-ensemble/regridded-zarrs/gaussian_grid_180_by_360/ic_0009 + ic_0010: gs://vcm-ml-raw-flexible-retention/2023-07-08-C96-FME-reference-ensemble/regridded-zarrs/gaussian_grid_180_by_360/ic_0010 + ic_0011: 
gs://vcm-ml-raw-flexible-retention/2023-07-08-C96-FME-reference-ensemble/regridded-zarrs/gaussian_grid_180_by_360/ic_0011 +data_output_directory: gs://vcm-ml-intermediate/2023-08-09-vertically-resolved-1deg-fme-ensemble-dataset +stats: + output_directory: gs://vcm-ml-intermediate/2023-08-09-vertically-resolved-1deg-fme-ensemble-dataset-stats + beaker_dataset: 2023-08-09-vertically-resolved-1deg-fme-ensemble-dataset-stats + start_date: "2021-01-01" + end_date: "2030-12-31" + data_type: FV3GFS + exclude_runs: + - "ic_0011" +dataset_computation: + reference_vertical_coordinate_file: gs://vcm-ml-raw-flexible-retention/2023-04-13-11-year-C96-FME-reference/vertical-coordinate-file/fv_core.res.nc + vertical_coarsening_indices: + - [0, 18] + - [18, 26] + - [26, 31] + - [31, 36] + - [36, 41] + - [41, 47] + - [47, 53] + - [53, 63] + renaming: + specific_humidity_at_two_meters: Q2m + variable_sources: + fluxes_2d.zarr: + - PRATEsfc + - LHTFLsfc + - SHTFLsfc + - DLWRFsfc + - DSWRFsfc + - DSWRFtoa + - ULWRFsfc + - ULWRFtoa + - USWRFsfc + - USWRFtoa + - precipitable_water_path + - GRAUPELsfc + - ICEsfc + - SNOWsfc + fourcastnet_vanilla.zarr: + - PRESsfc + - HGTsfc + - RH500 + - RH850 + - TMP500 + - TMP850 + - UGRD500 + - UGRD850 + - UGRD1000 + - VGRD500 + - VGRD850 + - VGRD1000 + - h50 + - h500 + - h850 + - h1000 + - TMP2m + - UGRD10m + - VGRD10m + full_state.zarr: + - surface_temperature + - air_temperature + - specific_humidity + - cloud_water_mixing_ratio + - cloud_ice_mixing_ratio + - graupel_mixing_ratio + - rain_mixing_ratio + - snow_mixing_ratio + - northward_wind + - eastward_wind + - pressure_thickness_of_atmospheric_layer + - soil_moisture + - specific_humidity_at_two_meters + encoded_surface_type.zarr: + - land_fraction + - ocean_fraction + - sea_ice_fraction diff --git a/scripts/data_process/configs/fv3gfs-ensemble-4deg-8layer.yaml b/scripts/data_process/configs/fv3gfs-ensemble-4deg-8layer.yaml new file mode 100644 index 0000000..22a0c2b --- /dev/null +++ b/scripts/data_process/configs/fv3gfs-ensemble-4deg-8layer.yaml @@ -0,0 +1,88 @@ +runs: + ic_0001: gs://vcm-ml-raw-flexible-retention/2023-07-08-C96-FME-reference-ensemble/regridded-zarrs/gaussian_grid_45_by_90/ic_0001 + ic_0002: gs://vcm-ml-raw-flexible-retention/2023-07-08-C96-FME-reference-ensemble/regridded-zarrs/gaussian_grid_45_by_90/ic_0002 + ic_0003: gs://vcm-ml-raw-flexible-retention/2023-07-08-C96-FME-reference-ensemble/regridded-zarrs/gaussian_grid_45_by_90/ic_0003 + ic_0004: gs://vcm-ml-raw-flexible-retention/2023-07-08-C96-FME-reference-ensemble/regridded-zarrs/gaussian_grid_45_by_90/ic_0004 + ic_0005: gs://vcm-ml-raw-flexible-retention/2023-07-08-C96-FME-reference-ensemble/regridded-zarrs/gaussian_grid_45_by_90/ic_0005 + ic_0006: gs://vcm-ml-raw-flexible-retention/2023-07-08-C96-FME-reference-ensemble/regridded-zarrs/gaussian_grid_45_by_90/ic_0006 + ic_0007: gs://vcm-ml-raw-flexible-retention/2023-07-08-C96-FME-reference-ensemble/regridded-zarrs/gaussian_grid_45_by_90/ic_0007 + ic_0008: gs://vcm-ml-raw-flexible-retention/2023-07-08-C96-FME-reference-ensemble/regridded-zarrs/gaussian_grid_45_by_90/ic_0008 + ic_0009: gs://vcm-ml-raw-flexible-retention/2023-07-08-C96-FME-reference-ensemble/regridded-zarrs/gaussian_grid_45_by_90/ic_0009 + ic_0010: gs://vcm-ml-raw-flexible-retention/2023-07-08-C96-FME-reference-ensemble/regridded-zarrs/gaussian_grid_45_by_90/ic_0010 + ic_0011: gs://vcm-ml-raw-flexible-retention/2023-07-08-C96-FME-reference-ensemble/regridded-zarrs/gaussian_grid_45_by_90/ic_0011 +data_output_directory: 
gs://vcm-ml-intermediate/2023-08-09-vertically-resolved-4deg-fme-ensemble-dataset +stats: + output_directory: gs://vcm-ml-intermediate/2023-08-09-vertically-resolved-4deg-fme-ensemble-dataset-stats + beaker_dataset: 2023-08-09-vertically-resolved-4deg-fme-ensemble-dataset-stats + start_date: "2021-01-01" + end_date: "2030-12-31" + data_type: FV3GFS + exclude_runs: + - "ic_0011" +dataset_computation: + reference_vertical_coordinate_file: gs://vcm-ml-raw-flexible-retention/2023-04-13-11-year-C96-FME-reference/vertical-coordinate-file/fv_core.res.nc + vertical_coarsening_indices: + - [0, 18] + - [18, 26] + - [26, 31] + - [31, 36] + - [36, 41] + - [41, 47] + - [47, 53] + - [53, 63] + renaming: + specific_humidity_at_two_meters: Q2m + variable_sources: + fluxes_2d.zarr: + - PRATEsfc + - LHTFLsfc + - SHTFLsfc + - DLWRFsfc + - DSWRFsfc + - DSWRFtoa + - ULWRFsfc + - ULWRFtoa + - USWRFsfc + - USWRFtoa + - precipitable_water_path + - GRAUPELsfc + - ICEsfc + - SNOWsfc + fourcastnet_vanilla.zarr: + - PRESsfc + - HGTsfc + - RH500 + - RH850 + - TMP500 + - TMP850 + - UGRD500 + - UGRD850 + - UGRD1000 + - VGRD500 + - VGRD850 + - VGRD1000 + - h50 + - h500 + - h850 + - h1000 + - TMP2m + - UGRD10m + - VGRD10m + full_state.zarr: + - surface_temperature + - air_temperature + - specific_humidity + - cloud_water_mixing_ratio + - cloud_ice_mixing_ratio + - graupel_mixing_ratio + - rain_mixing_ratio + - snow_mixing_ratio + - northward_wind + - eastward_wind + - pressure_thickness_of_atmospheric_layer + - soil_moisture + - specific_humidity_at_two_meters + encoded_surface_type.zarr: + - land_fraction + - ocean_fraction + - sea_ice_fraction diff --git a/scripts/data_process/configs/healpix-1deg-8layer-1940-2022.yaml b/scripts/data_process/configs/healpix-1deg-8layer-1940-2022.yaml new file mode 100644 index 0000000..0f56b31 --- /dev/null +++ b/scripts/data_process/configs/healpix-1deg-8layer-1940-2022.yaml @@ -0,0 +1,119 @@ +runs: + run_directory: /mntdata +stats: + output_directory: /mntdata/2024-08-21-healpix-era5-dataset + beaker_dataset: 2024-08-21-healpix-era5-dataset-stats + start_date: "1990-01-01" + end_date: "2019-12-31" + data_type: ERA5 +data_output_directory: /mntdata/2024-08-21-healpix-era5-dataset +dataset_computation: + n_split: 400 + reference_vertical_coordinate_file: gs://vcm-ml-raw-flexible-retention/2023-04-13-11-year-C96-FME-reference/vertical-coordinate-file/fv_core.res.nc + vertical_coarsening_indices: + - [0, 18] + standard_names: + face_dim: "face" + chunking: + face_dim: -1 + variable_sources: + 2024-06-20-era5-1deg-8layer-1940-2022.zarr: + - DLWRFsfc + - DPT2m + - DSWRFsfc + - DSWRFtoa + - HGTsfc + - LHTFLsfc + - PRATEsfc + - PRESsfc + - Q200 + - Q2m + - Q500 + - Q850 + - SHTFLsfc + - TMP200 + - TMP2m + - TMP500 + - TMP850 + - UGRD10m + - UGRD200 + - UGRD500 + - UGRD850 + - ULWRFsfc + - ULWRFtoa + - USWRFsfc + - USWRFtoa + - VGRD10m + - VGRD200 + - VGRD500 + - VGRD850 + - air_temperature_0 + - air_temperature_1 + - air_temperature_2 + - air_temperature_3 + - air_temperature_4 + - air_temperature_5 + - air_temperature_6 + - air_temperature_7 + - ak_0 + - ak_1 + - ak_2 + - ak_3 + - ak_4 + - ak_5 + - ak_6 + - ak_7 + - ak_8 + - bk_0 + - bk_1 + - bk_2 + - bk_3 + - bk_4 + - bk_5 + - bk_6 + - bk_7 + - bk_8 + - eastward_wind_0 + - eastward_wind_1 + - eastward_wind_2 + - eastward_wind_3 + - eastward_wind_4 + - eastward_wind_5 + - eastward_wind_6 + - eastward_wind_7 + - h1000 + - h200 + - h250 + - h300 + - h500 + - h700 + - h850 + - land_fraction + - latitude + - longitude + - northward_wind_0 + - 
northward_wind_1 + - northward_wind_2 + - northward_wind_3 + - northward_wind_4 + - northward_wind_5 + - northward_wind_6 + - northward_wind_7 + - ocean_fraction + - sea_ice_fraction + - soil_moisture_0 + - soil_moisture_1 + - soil_moisture_2 + - soil_moisture_3 + - specific_total_water_0 + - specific_total_water_1 + - specific_total_water_2 + - specific_total_water_3 + - specific_total_water_4 + - specific_total_water_5 + - specific_total_water_6 + - specific_total_water_7 + - surface_temperature + - tendency_of_total_water_path_due_to_advection + - time + - total_column_water_vapour \ No newline at end of file diff --git a/scripts/data_process/configs/pre-industrial-CM4-1deg-8layer-trial-run.yaml b/scripts/data_process/configs/pre-industrial-CM4-1deg-8layer-trial-run.yaml new file mode 100644 index 0000000..9342e6f --- /dev/null +++ b/scripts/data_process/configs/pre-industrial-CM4-1deg-8layer-trial-run.yaml @@ -0,0 +1,71 @@ +runs: + 2024-09-20-cm4-1deg-8layer-trial-run: gs://vcm-ml-raw-flexible-retention/2024-08-10-pre-industrial-CM4-simulation/regridded-zarrs/gaussian_grid_180_by_360/trial-run +data_output_directory: gs://vcm-ml-intermediate +stats: + output_directory: gs://vcm-ml-intermediate/2024-09-20-cm4-1deg-8layer-trial-run-stats + start_date: "0151-01-01" + end_date: "0159-01-01" + data_type: CM4 + beaker_dataset: not-used +dataset_computation: + reference_vertical_coordinate_file: gs://vcm-ml-raw-flexible-retention/2024-08-10-pre-industrial-CM4-simulation/vertical-coordinate-file/fv_core.res.nc + vertical_coarsening_indices: + # computed here: https://github.com/ai2cm/explore/blob/master/jamesd/2024-08-13-pre-industiral-CM4-eda/2024-08-28-AM4-vertical-indices.ipynb + - [0, 7] + - [7, 10] + - [10, 13] + - [13, 16] + - [16, 18] + - [18, 22] + - [22, 25] + - [25, 33] + renaming: + specific_humidity_at_two_meters: Q2m + air_temperature_at_two_meters: TMP2m + eastward_wind_at_ten_meters: UGRD10m + northward_wind_at_ten_meters: VGRD10m + variable_sources: + fluxes_2d.zarr: + - PRATEsfc + - SHTFLsfc + - DLWRFsfc + - DSWRFsfc + - DSWRFtoa + - ULWRFsfc + - ULWRFtoa + - USWRFsfc + - USWRFtoa + - eastward_surface_wind_stress + - northward_surface_wind_stress + - surface_evaporation_rate + - total_energy + - total_frozen_precipitation_rate + full_state.zarr: + # 2D vars + - HGTsfc # static + - PRESsfc + - surface_temperature + - air_temperature_at_two_meters + - specific_humidity_at_two_meters + - eastward_wind_at_ten_meters + - northward_wind_at_ten_meters + # 3D vars: + - air_temperature + - specific_humidity # water species + - cloud_water_mixing_ratio # water species + - cloud_ice_mixing_ratio # water species + - eastward_wind + - northward_wind + land_static.zarr: + - land_fraction + full_state_land.zarr: + - column_soil_moisture + full_state_ice.zarr: + - sea_ice_fraction + standard_names: + longitude_dim: lon + latitude_dim: lat + graupel_mixing_ratio: None + rain_mixing_ratio: None + snow_mixing_ratio: None + precipitable_water_path: None diff --git a/scripts/data_process/configs/shield-amip-ensemble-c24-4deg-8layer.yaml b/scripts/data_process/configs/shield-amip-ensemble-c24-4deg-8layer.yaml new file mode 100644 index 0000000..24103f8 --- /dev/null +++ b/scripts/data_process/configs/shield-amip-ensemble-c24-4deg-8layer.yaml @@ -0,0 +1,90 @@ +runs: + ic_0001: gs://vcm-ml-raw-flexible-retention/2024-11-11-C24-SHiELD-AMIP-tuned-cdmbgwd/regridded-zarrs/gaussian_grid_45_by_90/ic_0001 + ic_0002: 
gs://vcm-ml-raw-flexible-retention/2024-11-11-C24-SHiELD-AMIP-tuned-cdmbgwd/regridded-zarrs/gaussian_grid_45_by_90/ic_0002 +data_output_directory: gs://vcm-ml-intermediate/2024-11-11-vertically-resolved-c24-4deg-shield-amip-tuned-cdmbgwd-ensemble-dataset +stats: + output_directory: gs://vcm-ml-intermediate/2024-11-11-vertically-resolved-c24-4deg-shield-amip-tuned-cdmbgwd-ensemble-dataset-stats + beaker_dataset: 2024-11-11-vertically-resolved-c24-4deg-shield-amip-tuned-cdmbgwd-ensemble-dataset-stats + start_date: "1940-01-01" + end_date: "2021-12-31" + data_type: FV3GFS + exclude_runs: + - "ic_0002" +dataset_computation: + reference_vertical_coordinate_file: gs://vcm-ml-raw-flexible-retention/2024-03-10-C96-SHiELD-FME-reference/vertical-coordinate-file/fv_core.res.nc + vertical_coarsening_indices: + - [0, 11] + - [11, 21] + - [21, 30] + - [30, 39] + - [39, 49] + - [49, 58] + - [58, 67] + - [67, 79] + renaming: + specific_humidity_at_two_meters: Q2m + air_temperature_at_two_meters: TMP2m + eastward_wind_at_ten_meters: UGRD10m + northward_wind_at_ten_meters: VGRD10m + variable_sources: + fluxes_2d.zarr: + - PRATEsfc + - LHTFLsfc + - SHTFLsfc + - DLWRFsfc + - DSWRFsfc + - DSWRFtoa + - ULWRFsfc + - ULWRFtoa + - USWRFsfc + - USWRFtoa + - precipitable_water_path + - GRAUPELsfc + - ICEsfc + - SNOWsfc + full_state.zarr: + - PRESsfc + - HGTsfc + - RH500 + - RH850 + - TMP500 + - TMP850 + - UGRD500 + - UGRD850 + - UGRD1000 + - VGRD500 + - VGRD850 + - VGRD1000 + - h50 + - h500 + - h850 + - h1000 + - air_temperature_at_two_meters + - eastward_wind_at_ten_meters + - northward_wind_at_ten_meters + - surface_temperature + - air_temperature + - specific_humidity + - cloud_water_mixing_ratio + - cloud_ice_mixing_ratio + - graupel_mixing_ratio + - rain_mixing_ratio + - snow_mixing_ratio + - northward_wind + - eastward_wind + - pressure_thickness_of_atmospheric_layer + - soil_moisture_0 + - soil_moisture_1 + - soil_moisture_2 + - soil_moisture_3 + - snow_cover_fraction + - specific_humidity_at_two_meters + - land_fraction + - ocean_fraction + - sea_ice_fraction + - UGRD200 + - VGRD200 + - TMP200 + - RH200 + scalar.zarr: + - global_mean_co2 diff --git a/scripts/data_process/configs/shield-amip-ensemble-c96-1deg-8layer.yaml b/scripts/data_process/configs/shield-amip-ensemble-c96-1deg-8layer.yaml new file mode 100644 index 0000000..c946ed0 --- /dev/null +++ b/scripts/data_process/configs/shield-amip-ensemble-c96-1deg-8layer.yaml @@ -0,0 +1,91 @@ +runs: + ic_0001: gs://vcm-ml-raw-flexible-retention/2024-06-29-C96-SHiELD-AMIP/regridded-zarrs/gaussian_grid_180_by_360/ic_0001 + ic_0002: gs://vcm-ml-raw-flexible-retention/2024-06-29-C96-SHiELD-AMIP/regridded-zarrs/gaussian_grid_180_by_360/ic_0002 +data_output_directory: gs://vcm-ml-intermediate/2024-07-24-vertically-resolved-c96-1deg-shield-amip-ensemble-dataset +stats: + output_directory: gs://vcm-ml-intermediate/2024-07-24-vertically-resolved-c96-1deg-shield-amip-ensemble-dataset-stats + beaker_dataset: 2024-07-24-vertically-resolved-c96-1deg-shield-amip-ensemble-dataset-stats + start_date: "1940-01-01" + end_date: "2021-12-31" + data_type: FV3GFS + exclude_runs: + - "ic_0002" +dataset_computation: + reference_vertical_coordinate_file: gs://vcm-ml-raw-flexible-retention/2024-03-10-C96-SHiELD-FME-reference/vertical-coordinate-file/fv_core.res.nc + vertical_coarsening_indices: + - [0, 11] + - [11, 21] + - [21, 30] + - [30, 39] + - [39, 49] + - [49, 58] + - [58, 67] + - [67, 79] + renaming: + specific_humidity_at_two_meters: Q2m + air_temperature_at_two_meters: TMP2m + 
eastward_wind_at_ten_meters: UGRD10m + northward_wind_at_ten_meters: VGRD10m + variable_sources: + fluxes_2d.zarr: + - PRATEsfc + - LHTFLsfc + - SHTFLsfc + - DLWRFsfc + - DSWRFsfc + - DSWRFtoa + - ULWRFsfc + - ULWRFtoa + - USWRFsfc + - USWRFtoa + - precipitable_water_path + - GRAUPELsfc + - ICEsfc + - SNOWsfc + full_state.zarr: + - PRESsfc + - HGTsfc + - RH500 + - RH850 + - TMP500 + - TMP850 + - UGRD500 + - UGRD850 + - UGRD1000 + - VGRD500 + - VGRD850 + - VGRD1000 + - h50 + - h500 + - h850 + - h1000 + - air_temperature_at_two_meters + - eastward_wind_at_ten_meters + - northward_wind_at_ten_meters + - surface_temperature + - air_temperature + - specific_humidity + - cloud_water_mixing_ratio + - cloud_ice_mixing_ratio + - graupel_mixing_ratio + - rain_mixing_ratio + - snow_mixing_ratio + - northward_wind + - eastward_wind + - pressure_thickness_of_atmospheric_layer + - soil_moisture_0 + - soil_moisture_1 + - soil_moisture_2 + - soil_moisture_3 + - snow_cover_fraction + - specific_humidity_at_two_meters + - land_fraction + - ocean_fraction + - sea_ice_fraction + UGRD200_VGRD200_TMP200_RH200.zarr: + - UGRD200 + - VGRD200 + - TMP200 + - RH200 + scalar.zarr: + - global_mean_co2 diff --git a/scripts/data_process/configs/shield-amip-ensemble-c96-4deg-8layer.yaml b/scripts/data_process/configs/shield-amip-ensemble-c96-4deg-8layer.yaml new file mode 100644 index 0000000..a2ce8e6 --- /dev/null +++ b/scripts/data_process/configs/shield-amip-ensemble-c96-4deg-8layer.yaml @@ -0,0 +1,90 @@ +runs: + ic_0001: gs://vcm-ml-raw-flexible-retention/2024-06-29-C96-SHiELD-AMIP/regridded-zarrs/gaussian_grid_45_by_90/ic_0001 + ic_0002: gs://vcm-ml-raw-flexible-retention/2024-06-29-C96-SHiELD-AMIP/regridded-zarrs/gaussian_grid_45_by_90/ic_0002 +data_output_directory: gs://vcm-ml-intermediate/2024-07-24-vertically-resolved-c96-4deg-shield-amip-ensemble-dataset +stats: + output_directory: gs://vcm-ml-intermediate/2024-07-24-vertically-resolved-c96-4deg-shield-amip-ensemble-dataset-stats + beaker_dataset: 2024-07-24-vertically-resolved-c96-4deg-shield-amip-ensemble-dataset-stats + start_date: "1940-01-01" + end_date: "2021-12-31" + data_type: FV3GFS + exclude_runs: + - "ic_0002" +dataset_computation: + reference_vertical_coordinate_file: gs://vcm-ml-raw-flexible-retention/2024-03-10-C96-SHiELD-FME-reference/vertical-coordinate-file/fv_core.res.nc + vertical_coarsening_indices: + - [0, 11] + - [11, 21] + - [21, 30] + - [30, 39] + - [39, 49] + - [49, 58] + - [58, 67] + - [67, 79] + renaming: + specific_humidity_at_two_meters: Q2m + air_temperature_at_two_meters: TMP2m + eastward_wind_at_ten_meters: UGRD10m + northward_wind_at_ten_meters: VGRD10m + variable_sources: + fluxes_2d.zarr: + - PRATEsfc + - LHTFLsfc + - SHTFLsfc + - DLWRFsfc + - DSWRFsfc + - DSWRFtoa + - ULWRFsfc + - ULWRFtoa + - USWRFsfc + - USWRFtoa + - precipitable_water_path + - GRAUPELsfc + - ICEsfc + - SNOWsfc + full_state.zarr: + - PRESsfc + - HGTsfc + - RH500 + - RH850 + - TMP500 + - TMP850 + - UGRD500 + - UGRD850 + - VGRD500 + - VGRD850 + - h50 + - h200 + - h500 + - h850 + - air_temperature_at_two_meters + - eastward_wind_at_ten_meters + - northward_wind_at_ten_meters + - surface_temperature + - air_temperature + - specific_humidity + - cloud_water_mixing_ratio + - cloud_ice_mixing_ratio + - graupel_mixing_ratio + - rain_mixing_ratio + - snow_mixing_ratio + - northward_wind + - eastward_wind + - pressure_thickness_of_atmospheric_layer + - soil_moisture_0 + - soil_moisture_1 + - soil_moisture_2 + - soil_moisture_3 + - column_soil_moisture + - 
snow_cover_fraction + - specific_humidity_at_two_meters + - land_fraction + - ocean_fraction + - sea_ice_fraction + UGRD200_VGRD200_TMP200_RH200.zarr: + - UGRD200 + - VGRD200 + - TMP200 + - RH200 + scalar.zarr: + - global_mean_co2 diff --git a/scripts/data_process/configs/shield-c24-ensemble-4deg-8layer.yaml b/scripts/data_process/configs/shield-c24-ensemble-4deg-8layer.yaml new file mode 100644 index 0000000..fd12e65 --- /dev/null +++ b/scripts/data_process/configs/shield-c24-ensemble-4deg-8layer.yaml @@ -0,0 +1,89 @@ +runs: + ic_0001: gs://vcm-ml-raw-flexible-retention/2024-03-08-climSST-C24-SHiELD-FME-reference/regridded-zarrs/gaussian_grid_45_by_90/ic_0001 + ic_0002: gs://vcm-ml-raw-flexible-retention/2024-03-08-climSST-C24-SHiELD-FME-reference/regridded-zarrs/gaussian_grid_45_by_90/ic_0002 + ic_0003: gs://vcm-ml-raw-flexible-retention/2024-03-08-climSST-C24-SHiELD-FME-reference/regridded-zarrs/gaussian_grid_45_by_90/ic_0003 + ic_0004: gs://vcm-ml-raw-flexible-retention/2024-03-08-climSST-C24-SHiELD-FME-reference/regridded-zarrs/gaussian_grid_45_by_90/ic_0004 + ic_0005: gs://vcm-ml-raw-flexible-retention/2024-03-08-climSST-C24-SHiELD-FME-reference/regridded-zarrs/gaussian_grid_45_by_90/ic_0005 + ic_0006: gs://vcm-ml-raw-flexible-retention/2024-03-08-climSST-C24-SHiELD-FME-reference/regridded-zarrs/gaussian_grid_45_by_90/ic_0006 + ic_0007: gs://vcm-ml-raw-flexible-retention/2024-03-08-climSST-C24-SHiELD-FME-reference/regridded-zarrs/gaussian_grid_45_by_90/ic_0007 + ic_0008: gs://vcm-ml-raw-flexible-retention/2024-03-08-climSST-C24-SHiELD-FME-reference/regridded-zarrs/gaussian_grid_45_by_90/ic_0008 + ic_0009: gs://vcm-ml-raw-flexible-retention/2024-03-08-climSST-C24-SHiELD-FME-reference/regridded-zarrs/gaussian_grid_45_by_90/ic_0009 + ic_0010: gs://vcm-ml-raw-flexible-retention/2024-03-08-climSST-C24-SHiELD-FME-reference/regridded-zarrs/gaussian_grid_45_by_90/ic_0010 + ic_0011: gs://vcm-ml-raw-flexible-retention/2024-03-08-climSST-C24-SHiELD-FME-reference/regridded-zarrs/gaussian_grid_45_by_90/ic_0011 + ic_0012: gs://vcm-ml-raw-flexible-retention/2024-03-08-climSST-C24-SHiELD-FME-reference/regridded-zarrs/gaussian_grid_45_by_90/ic_0012 + ic_0013: gs://vcm-ml-raw-flexible-retention/2024-03-08-climSST-C24-SHiELD-FME-reference/regridded-zarrs/gaussian_grid_45_by_90/ic_0013 + ic_0014: gs://vcm-ml-raw-flexible-retention/2024-03-08-climSST-C24-SHiELD-FME-reference/regridded-zarrs/gaussian_grid_45_by_90/ic_0014 + ic_0015: gs://vcm-ml-raw-flexible-retention/2024-03-08-climSST-C24-SHiELD-FME-reference/regridded-zarrs/gaussian_grid_45_by_90/ic_0015 + ic_0016: gs://vcm-ml-raw-flexible-retention/2024-03-08-climSST-C24-SHiELD-FME-reference/regridded-zarrs/gaussian_grid_45_by_90/ic_0016 + ic_0017: gs://vcm-ml-raw-flexible-retention/2024-03-08-climSST-C24-SHiELD-FME-reference/regridded-zarrs/gaussian_grid_45_by_90/ic_0017 + ic_0018: gs://vcm-ml-raw-flexible-retention/2024-03-08-climSST-C24-SHiELD-FME-reference/regridded-zarrs/gaussian_grid_45_by_90/ic_0018 + ic_0019: gs://vcm-ml-raw-flexible-retention/2024-03-08-climSST-C24-SHiELD-FME-reference/regridded-zarrs/gaussian_grid_45_by_90/ic_0019 + ic_0020: gs://vcm-ml-raw-flexible-retention/2024-03-08-climSST-C24-SHiELD-FME-reference/regridded-zarrs/gaussian_grid_45_by_90/ic_0020 + ic_0021: gs://vcm-ml-raw-flexible-retention/2024-03-08-climSST-C24-SHiELD-FME-reference/regridded-zarrs/gaussian_grid_45_by_90/ic_0021 +data_output_directory: gs://vcm-ml-intermediate/2024-04-05-vertically-resolved-4deg-c24-shield-fme-ensemble-dataset +stats: + output_directory: 
gs://vcm-ml-intermediate/2024-04-05-vertically-resolved-4deg-c24-shield-fme-ensemble-dataset-stats + beaker_dataset: 2024-04-05-vertically-resolved-4deg-c24-shield-fme-ensemble-dataset-stats + start_date: "2021-01-01" + end_date: "2030-12-31" + exclude_runs: + - "ic_0021" + data_type: FV3GFS +dataset_computation: + reference_vertical_coordinate_file: gs://vcm-ml-raw-flexible-retention/2024-03-10-C96-SHiELD-FME-reference/vertical-coordinate-file/fv_core.res.nc + vertical_coarsening_indices: + - [0, 11] + - [11, 21] + - [21, 30] + - [30, 39] + - [39, 49] + - [49, 58] + - [58, 67] + - [67, 79] + renaming: + specific_humidity_at_two_meters: Q2m + air_temperature_at_two_meters: TMP2m + eastward_wind_at_ten_meters: UGRD10m + northward_wind_at_ten_meters: VGRD10m + variable_sources: + fluxes_2d.zarr: + - PRATEsfc + - LHTFLsfc + - SHTFLsfc + - DLWRFsfc + - DSWRFsfc + - DSWRFtoa + - ULWRFsfc + - ULWRFtoa + - USWRFsfc + - USWRFtoa + - precipitable_water_path + - GRAUPELsfc + - ICEsfc + - SNOWsfc + full_state.zarr: + - surface_temperature + - air_temperature + - specific_humidity + - cloud_water_mixing_ratio + - cloud_ice_mixing_ratio + - graupel_mixing_ratio + - rain_mixing_ratio + - snow_mixing_ratio + - northward_wind + - eastward_wind + - pressure_thickness_of_atmospheric_layer + - PRESsfc + - HGTsfc + - column_soil_moisture + - soil_moisture_0 + - soil_moisture_1 + - soil_moisture_2 + - soil_moisture_3 + - land_fraction + - ocean_fraction + - sea_ice_fraction + - specific_humidity_at_two_meters + - air_temperature_at_two_meters + - northward_wind_at_ten_meters + - eastward_wind_at_ten_meters \ No newline at end of file diff --git a/scripts/data_process/configs/shield-c96-4deg-8layer.yaml b/scripts/data_process/configs/shield-c96-4deg-8layer.yaml new file mode 100644 index 0000000..23a5116 --- /dev/null +++ b/scripts/data_process/configs/shield-c96-4deg-8layer.yaml @@ -0,0 +1,68 @@ +runs: + ic_0001: gs://vcm-ml-raw-flexible-retention/2024-03-10-C96-SHiELD-FME-reference/regridded-zarrs/gaussian_grid_45_by_90/repeating-sst +data_output_directory: gs://vcm-ml-intermediate/2024-04-02-vertically-resolved-4deg-c96-shield-fme-dataset +stats: + output_directory: gs://vcm-ml-intermediate/2024-04-02-vertically-resolved-4deg-c96-shield-fme-dataset-stats + beaker_dataset: 2024-04-02-vertically-resolved-4deg-c96-shield-fme-dataset-stats + # start_date: "2035-01-01" # start of run + start_date: "2036-01-01" # we exclude just the first year so we can use it as an initial condition + end_date: "2060-12-31" # end of run + data_type: FV3GFS +dataset_computation: + reference_vertical_coordinate_file: gs://vcm-ml-raw-flexible-retention/2024-03-10-C96-SHiELD-FME-reference/vertical-coordinate-file/fv_core.res.nc + vertical_coarsening_indices: + - [0, 11] + - [11, 21] + - [21, 30] + - [30, 39] + - [39, 49] + - [49, 58] + - [58, 67] + - [67, 79] + renaming: + specific_humidity_at_two_meters: Q2m + air_temperature_at_two_meters: TMP2m + eastward_wind_at_ten_meters: UGRD10m + northward_wind_at_ten_meters: VGRD10m + variable_sources: + fluxes_2d.zarr: + - PRATEsfc + - LHTFLsfc + - SHTFLsfc + - DLWRFsfc + - DSWRFsfc + - DSWRFtoa + - ULWRFsfc + - ULWRFtoa + - USWRFsfc + - USWRFtoa + - precipitable_water_path + - GRAUPELsfc + - ICEsfc + - SNOWsfc + full_state.zarr: + - surface_temperature + - air_temperature + - specific_humidity + - cloud_water_mixing_ratio + - cloud_ice_mixing_ratio + - graupel_mixing_ratio + - rain_mixing_ratio + - snow_mixing_ratio + - northward_wind + - eastward_wind + - 
pressure_thickness_of_atmospheric_layer + - PRESsfc + - HGTsfc + - column_soil_moisture + - soil_moisture_0 + - soil_moisture_1 + - soil_moisture_2 + - soil_moisture_3 + - land_fraction + - ocean_fraction + - sea_ice_fraction + - specific_humidity_at_two_meters + - air_temperature_at_two_meters + - northward_wind_at_ten_meters + - eastward_wind_at_ten_meters \ No newline at end of file diff --git a/scripts/data_process/configs/shield-som-abrupt-co2-increase-c96-1deg-8layer.yaml b/scripts/data_process/configs/shield-som-abrupt-co2-increase-c96-1deg-8layer.yaml new file mode 100644 index 0000000..159d23c --- /dev/null +++ b/scripts/data_process/configs/shield-som-abrupt-co2-increase-c96-1deg-8layer.yaml @@ -0,0 +1,94 @@ +runs: + abrupt-2xCO2: gs://vcm-ml-raw-flexible-retention/2024-07-03-C96-SHiELD-SOM/regridded-zarrs/gaussian_grid_180_by_360/abrupt-2xCO2 + abrupt-3xCO2: gs://vcm-ml-raw-flexible-retention/2024-07-03-C96-SHiELD-SOM/regridded-zarrs/gaussian_grid_180_by_360/abrupt-3xCO2 + abrupt-4xCO2: gs://vcm-ml-raw-flexible-retention/2024-07-03-C96-SHiELD-SOM/regridded-zarrs/gaussian_grid_180_by_360/abrupt-4xCO2 +data_output_directory: gs://vcm-ml-intermediate/2024-08-14-vertically-resolved-1deg-c96-shield-som-abrupt-co2-increase-fme-dataset +stats: + output_directory: gs://vcm-ml-intermediate/2024-08-14-vertically-resolved-1deg-c96-shield-som-abrupt-co2-increase-fme-dataset-stats + beaker_dataset: 2024-07-16-vertically-resolved-1deg-fme-c96-shield-som-abrupt-co2-increase-dataset-stats + start_date: null + end_date: null + data_type: FV3GFS + exclude_runs: + - abrupt-2xCO2 + - abrupt-3xCO2 + - abrupt-4xCO2 +dataset_computation: + reference_vertical_coordinate_file: gs://vcm-ml-raw-flexible-retention/2024-03-10-C96-SHiELD-FME-reference/vertical-coordinate-file/fv_core.res.nc + vertical_coarsening_indices: + - [0, 11] + - [11, 21] + - [21, 30] + - [30, 39] + - [39, 49] + - [49, 58] + - [58, 67] + - [67, 79] + renaming: + specific_humidity_at_two_meters: Q2m + air_temperature_at_two_meters: TMP2m + eastward_wind_at_ten_meters: UGRD10m + northward_wind_at_ten_meters: VGRD10m + variable_sources: + fluxes_2d.zarr: + - PRATEsfc + - LHTFLsfc + - SHTFLsfc + - DLWRFsfc + - DSWRFsfc + - DSWRFtoa + - ULWRFsfc + - ULWRFtoa + - USWRFsfc + - USWRFtoa + - precipitable_water_path + - GRAUPELsfc + - ICEsfc + - SNOWsfc + full_state.zarr: + - surface_temperature + - air_temperature + - specific_humidity + - cloud_water_mixing_ratio + - cloud_ice_mixing_ratio + - graupel_mixing_ratio + - rain_mixing_ratio + - snow_mixing_ratio + - northward_wind + - eastward_wind + - pressure_thickness_of_atmospheric_layer + - PRESsfc + - HGTsfc + - column_soil_moisture + - soil_moisture_0 + - soil_moisture_1 + - soil_moisture_2 + - soil_moisture_3 + - land_fraction + - ocean_fraction + - sea_ice_fraction + - specific_humidity_at_two_meters + - air_temperature_at_two_meters + - northward_wind_at_ten_meters + - eastward_wind_at_ten_meters + - RH200 + - RH500 + - RH850 + - TMP200 + - TMP500 + - TMP850 + - UGRD200 + - UGRD500 + - UGRD850 + - VGRD200 + - VGRD500 + - VGRD850 + - h50 + - h200 + - h500 + - h850 + ocean_forcing.zarr: + - prescribed_mixed_layer_depth + - prescribed_qflux + scalar.zarr: + - global_mean_co2 diff --git a/scripts/data_process/configs/shield-som-abrupt-co2-increase-c96-4deg-8layer.yaml b/scripts/data_process/configs/shield-som-abrupt-co2-increase-c96-4deg-8layer.yaml new file mode 100644 index 0000000..746f873 --- /dev/null +++ b/scripts/data_process/configs/shield-som-abrupt-co2-increase-c96-4deg-8layer.yaml 
@@ -0,0 +1,94 @@ +runs: + abrupt-2xCO2: gs://vcm-ml-raw-flexible-retention/2024-07-03-C96-SHiELD-SOM/regridded-zarrs/gaussian_grid_45_by_90/abrupt-2xCO2 + abrupt-3xCO2: gs://vcm-ml-raw-flexible-retention/2024-07-03-C96-SHiELD-SOM/regridded-zarrs/gaussian_grid_45_by_90/abrupt-3xCO2 + abrupt-4xCO2: gs://vcm-ml-raw-flexible-retention/2024-07-03-C96-SHiELD-SOM/regridded-zarrs/gaussian_grid_45_by_90/abrupt-4xCO2 +data_output_directory: gs://vcm-ml-intermediate/2024-08-14-vertically-resolved-4deg-c96-shield-som-abrupt-co2-increase-fme-dataset +stats: + output_directory: gs://vcm-ml-intermediate/2024-08-14-vertically-resolved-4deg-c96-shield-som-abrupt-co2-increase-fme-dataset-stats + beaker_dataset: 2024-08-14-vertically-resolved-4deg-fme-c96-shield-som-abrupt-co2-increase-dataset-stats + start_date: null + end_date: null + data_type: FV3GFS + exclude_runs: + - abrupt-2xCO2 + - abrupt-3xCO2 + - abrupt-4xCO2 +dataset_computation: + reference_vertical_coordinate_file: gs://vcm-ml-raw-flexible-retention/2024-03-10-C96-SHiELD-FME-reference/vertical-coordinate-file/fv_core.res.nc + vertical_coarsening_indices: + - [0, 11] + - [11, 21] + - [21, 30] + - [30, 39] + - [39, 49] + - [49, 58] + - [58, 67] + - [67, 79] + renaming: + specific_humidity_at_two_meters: Q2m + air_temperature_at_two_meters: TMP2m + eastward_wind_at_ten_meters: UGRD10m + northward_wind_at_ten_meters: VGRD10m + variable_sources: + fluxes_2d.zarr: + - PRATEsfc + - LHTFLsfc + - SHTFLsfc + - DLWRFsfc + - DSWRFsfc + - DSWRFtoa + - ULWRFsfc + - ULWRFtoa + - USWRFsfc + - USWRFtoa + - precipitable_water_path + - GRAUPELsfc + - ICEsfc + - SNOWsfc + full_state.zarr: + - surface_temperature + - air_temperature + - specific_humidity + - cloud_water_mixing_ratio + - cloud_ice_mixing_ratio + - graupel_mixing_ratio + - rain_mixing_ratio + - snow_mixing_ratio + - northward_wind + - eastward_wind + - pressure_thickness_of_atmospheric_layer + - PRESsfc + - HGTsfc + - column_soil_moisture + - soil_moisture_0 + - soil_moisture_1 + - soil_moisture_2 + - soil_moisture_3 + - land_fraction + - ocean_fraction + - sea_ice_fraction + - specific_humidity_at_two_meters + - air_temperature_at_two_meters + - northward_wind_at_ten_meters + - eastward_wind_at_ten_meters + - RH200 + - RH500 + - RH850 + - TMP200 + - TMP500 + - TMP850 + - UGRD200 + - UGRD500 + - UGRD850 + - VGRD200 + - VGRD500 + - VGRD850 + - h50 + - h200 + - h500 + - h850 + ocean_forcing.zarr: + - prescribed_mixed_layer_depth + - prescribed_qflux + scalar.zarr: + - global_mean_co2 diff --git a/scripts/data_process/configs/shield-som-c24-4deg-8layer.yaml b/scripts/data_process/configs/shield-som-c24-4deg-8layer.yaml new file mode 100644 index 0000000..ffbffa0 --- /dev/null +++ b/scripts/data_process/configs/shield-som-c24-4deg-8layer.yaml @@ -0,0 +1,140 @@ +runs: + 1xCO2-ic_0001: gs://vcm-ml-raw-flexible-retention/2024-07-03-C24-SHiELD-SOM/regridded-zarrs/gaussian_grid_45_by_90/1xCO2-ic_0001 + 1xCO2-ic_0002: gs://vcm-ml-raw-flexible-retention/2024-07-03-C24-SHiELD-SOM/regridded-zarrs/gaussian_grid_45_by_90/1xCO2-ic_0002 + 1xCO2-ic_0003: gs://vcm-ml-raw-flexible-retention/2024-07-03-C24-SHiELD-SOM/regridded-zarrs/gaussian_grid_45_by_90/1xCO2-ic_0003 + 1xCO2-ic_0004: gs://vcm-ml-raw-flexible-retention/2024-07-03-C24-SHiELD-SOM/regridded-zarrs/gaussian_grid_45_by_90/1xCO2-ic_0004 + 1xCO2-ic_0005: gs://vcm-ml-raw-flexible-retention/2024-07-03-C24-SHiELD-SOM/regridded-zarrs/gaussian_grid_45_by_90/1xCO2-ic_0005 + 2xCO2-ic_0001: 
gs://vcm-ml-raw-flexible-retention/2024-07-03-C24-SHiELD-SOM/regridded-zarrs/gaussian_grid_45_by_90/2xCO2-ic_0001 + 2xCO2-ic_0002: gs://vcm-ml-raw-flexible-retention/2024-07-03-C24-SHiELD-SOM/regridded-zarrs/gaussian_grid_45_by_90/2xCO2-ic_0002 + 2xCO2-ic_0003: gs://vcm-ml-raw-flexible-retention/2024-07-03-C24-SHiELD-SOM/regridded-zarrs/gaussian_grid_45_by_90/2xCO2-ic_0003 + 2xCO2-ic_0004: gs://vcm-ml-raw-flexible-retention/2024-07-03-C24-SHiELD-SOM/regridded-zarrs/gaussian_grid_45_by_90/2xCO2-ic_0004 + 2xCO2-ic_0005: gs://vcm-ml-raw-flexible-retention/2024-07-03-C24-SHiELD-SOM/regridded-zarrs/gaussian_grid_45_by_90/2xCO2-ic_0005 + 3xCO2-ic_0001: gs://vcm-ml-raw-flexible-retention/2024-07-03-C24-SHiELD-SOM/regridded-zarrs/gaussian_grid_45_by_90/3xCO2-ic_0001 + 3xCO2-ic_0002: gs://vcm-ml-raw-flexible-retention/2024-07-03-C24-SHiELD-SOM/regridded-zarrs/gaussian_grid_45_by_90/3xCO2-ic_0002 + 3xCO2-ic_0003: gs://vcm-ml-raw-flexible-retention/2024-07-03-C24-SHiELD-SOM/regridded-zarrs/gaussian_grid_45_by_90/3xCO2-ic_0003 + 3xCO2-ic_0004: gs://vcm-ml-raw-flexible-retention/2024-07-03-C24-SHiELD-SOM/regridded-zarrs/gaussian_grid_45_by_90/3xCO2-ic_0004 + 3xCO2-ic_0005: gs://vcm-ml-raw-flexible-retention/2024-07-03-C24-SHiELD-SOM/regridded-zarrs/gaussian_grid_45_by_90/3xCO2-ic_0005 + 4xCO2-ic_0001: gs://vcm-ml-raw-flexible-retention/2024-07-03-C24-SHiELD-SOM/regridded-zarrs/gaussian_grid_45_by_90/4xCO2-ic_0001 + 4xCO2-ic_0002: gs://vcm-ml-raw-flexible-retention/2024-07-03-C24-SHiELD-SOM/regridded-zarrs/gaussian_grid_45_by_90/4xCO2-ic_0002 + 4xCO2-ic_0003: gs://vcm-ml-raw-flexible-retention/2024-07-03-C24-SHiELD-SOM/regridded-zarrs/gaussian_grid_45_by_90/4xCO2-ic_0003 + 4xCO2-ic_0004: gs://vcm-ml-raw-flexible-retention/2024-07-03-C24-SHiELD-SOM/regridded-zarrs/gaussian_grid_45_by_90/4xCO2-ic_0004 + 4xCO2-ic_0005: gs://vcm-ml-raw-flexible-retention/2024-07-03-C24-SHiELD-SOM/regridded-zarrs/gaussian_grid_45_by_90/4xCO2-ic_0005 + abrupt-2xCO2: gs://vcm-ml-raw-flexible-retention/2024-07-03-C24-SHiELD-SOM/regridded-zarrs/gaussian_grid_45_by_90/abrupt-2xCO2 + abrupt-3xCO2: gs://vcm-ml-raw-flexible-retention/2024-07-03-C24-SHiELD-SOM/regridded-zarrs/gaussian_grid_45_by_90/abrupt-3xCO2 + abrupt-4xCO2: gs://vcm-ml-raw-flexible-retention/2024-07-03-C24-SHiELD-SOM/regridded-zarrs/gaussian_grid_45_by_90/abrupt-4xCO2 + increasing-CO2: gs://vcm-ml-raw-flexible-retention/2024-07-03-C24-SHiELD-SOM/regridded-zarrs/gaussian_grid_45_by_90/increasing-CO2 +data_output_directory: gs://vcm-ml-intermediate/2024-07-17-vertically-resolved-4deg-c24-shield-som-baseline-fme-dataset +stats: + output_directory: gs://vcm-ml-intermediate/2024-07-17-vertically-resolved-4deg-c24-shield-som-ensemble-fme-dataset-stats + beaker_dataset: 2024-07-17-vertically-resolved-4deg-fme-c24-shield-som-ensemble-dataset-stats + # These datasets already exclude the one-year divergence period for each + # ensemble member, so we can use data from all times when computing the + # stats. 
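+  # Note: exclude_runs is read only by the stats step (see StatsConfig in
+  # get_stats.py), so the runs listed below are omitted from the normalization
+  # statistics while the dataset itself is still generated for every entry
+  # under the top-level runs mapping.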
+ start_date: null + end_date: null + data_type: FV3GFS + exclude_runs: + # Exclude all runs here since we are not interested in these for ML training + - 1xCO2-ic_0001 + - 1xCO2-ic_0002 + - 1xCO2-ic_0003 + - 1xCO2-ic_0004 + - 1xCO2-ic_0005 + - 2xCO2-ic_0001 + - 2xCO2-ic_0002 + - 2xCO2-ic_0003 + - 2xCO2-ic_0004 + - 2xCO2-ic_0005 + - 3xCO2-ic_0001 + - 3xCO2-ic_0002 + - 3xCO2-ic_0003 + - 3xCO2-ic_0004 + - 3xCO2-ic_0005 + - 4xCO2-ic_0001 + - 4xCO2-ic_0002 + - 4xCO2-ic_0003 + - 4xCO2-ic_0004 + - 4xCO2-ic_0005 + - abrupt-2xCO2 + - abrupt-3xCO2 + - abrupt-4xCO2 + - increasing-co2 +dataset_computation: + reference_vertical_coordinate_file: gs://vcm-ml-raw-flexible-retention/2024-03-10-C96-SHiELD-FME-reference/vertical-coordinate-file/fv_core.res.nc + vertical_coarsening_indices: + - [0, 11] + - [11, 21] + - [21, 30] + - [30, 39] + - [39, 49] + - [49, 58] + - [58, 67] + - [67, 79] + renaming: + specific_humidity_at_two_meters: Q2m + air_temperature_at_two_meters: TMP2m + eastward_wind_at_ten_meters: UGRD10m + northward_wind_at_ten_meters: VGRD10m + variable_sources: + fluxes_2d.zarr: + - PRATEsfc + - LHTFLsfc + - SHTFLsfc + - DLWRFsfc + - DSWRFsfc + - DSWRFtoa + - ULWRFsfc + - ULWRFtoa + - USWRFsfc + - USWRFtoa + - precipitable_water_path + - GRAUPELsfc + - ICEsfc + - SNOWsfc + full_state.zarr: + - surface_temperature + - air_temperature + - specific_humidity + - cloud_water_mixing_ratio + - cloud_ice_mixing_ratio + - graupel_mixing_ratio + - rain_mixing_ratio + - snow_mixing_ratio + - northward_wind + - eastward_wind + - pressure_thickness_of_atmospheric_layer + - PRESsfc + - HGTsfc + - column_soil_moisture + - soil_moisture_0 + - soil_moisture_1 + - soil_moisture_2 + - soil_moisture_3 + - land_fraction + - ocean_fraction + - sea_ice_fraction + - specific_humidity_at_two_meters + - air_temperature_at_two_meters + - northward_wind_at_ten_meters + - eastward_wind_at_ten_meters + - RH200 + - RH500 + - RH850 + - TMP200 + - TMP500 + - TMP850 + - UGRD200 + - UGRD500 + - UGRD850 + - VGRD200 + - VGRD500 + - VGRD850 + - h50 + - h200 + - h500 + - h850 + ocean_forcing.zarr: + - prescribed_mixed_layer_depth + - prescribed_qflux + scalar.zarr: + - global_mean_co2 diff --git a/scripts/data_process/configs/shield-som-c24-tuned-cdmbgwd-4deg-8layer.yaml b/scripts/data_process/configs/shield-som-c24-tuned-cdmbgwd-4deg-8layer.yaml new file mode 100644 index 0000000..fe93ebe --- /dev/null +++ b/scripts/data_process/configs/shield-som-c24-tuned-cdmbgwd-4deg-8layer.yaml @@ -0,0 +1,140 @@ +runs: + 1xCO2-ic_0001: gs://vcm-ml-raw-flexible-retention/2024-11-11-C24-SHiELD-SOM-tuned-cdmbgwd/regridded-zarrs/gaussian_grid_45_by_90/1xCO2-ic_0001 + 1xCO2-ic_0002: gs://vcm-ml-raw-flexible-retention/2024-11-11-C24-SHiELD-SOM-tuned-cdmbgwd/regridded-zarrs/gaussian_grid_45_by_90/1xCO2-ic_0002 + 1xCO2-ic_0003: gs://vcm-ml-raw-flexible-retention/2024-11-11-C24-SHiELD-SOM-tuned-cdmbgwd/regridded-zarrs/gaussian_grid_45_by_90/1xCO2-ic_0003 + 1xCO2-ic_0004: gs://vcm-ml-raw-flexible-retention/2024-11-11-C24-SHiELD-SOM-tuned-cdmbgwd/regridded-zarrs/gaussian_grid_45_by_90/1xCO2-ic_0004 + 1xCO2-ic_0005: gs://vcm-ml-raw-flexible-retention/2024-11-11-C24-SHiELD-SOM-tuned-cdmbgwd/regridded-zarrs/gaussian_grid_45_by_90/1xCO2-ic_0005 + 2xCO2-ic_0001: gs://vcm-ml-raw-flexible-retention/2024-11-11-C24-SHiELD-SOM-tuned-cdmbgwd/regridded-zarrs/gaussian_grid_45_by_90/2xCO2-ic_0001 + 2xCO2-ic_0002: gs://vcm-ml-raw-flexible-retention/2024-11-11-C24-SHiELD-SOM-tuned-cdmbgwd/regridded-zarrs/gaussian_grid_45_by_90/2xCO2-ic_0002 + 2xCO2-ic_0003: 
gs://vcm-ml-raw-flexible-retention/2024-11-11-C24-SHiELD-SOM-tuned-cdmbgwd/regridded-zarrs/gaussian_grid_45_by_90/2xCO2-ic_0003 + 2xCO2-ic_0004: gs://vcm-ml-raw-flexible-retention/2024-11-11-C24-SHiELD-SOM-tuned-cdmbgwd/regridded-zarrs/gaussian_grid_45_by_90/2xCO2-ic_0004 + 2xCO2-ic_0005: gs://vcm-ml-raw-flexible-retention/2024-11-11-C24-SHiELD-SOM-tuned-cdmbgwd/regridded-zarrs/gaussian_grid_45_by_90/2xCO2-ic_0005 + 3xCO2-ic_0001: gs://vcm-ml-raw-flexible-retention/2024-11-11-C24-SHiELD-SOM-tuned-cdmbgwd/regridded-zarrs/gaussian_grid_45_by_90/3xCO2-ic_0001 + 3xCO2-ic_0002: gs://vcm-ml-raw-flexible-retention/2024-11-11-C24-SHiELD-SOM-tuned-cdmbgwd/regridded-zarrs/gaussian_grid_45_by_90/3xCO2-ic_0002 + 3xCO2-ic_0003: gs://vcm-ml-raw-flexible-retention/2024-11-11-C24-SHiELD-SOM-tuned-cdmbgwd/regridded-zarrs/gaussian_grid_45_by_90/3xCO2-ic_0003 + 3xCO2-ic_0004: gs://vcm-ml-raw-flexible-retention/2024-11-11-C24-SHiELD-SOM-tuned-cdmbgwd/regridded-zarrs/gaussian_grid_45_by_90/3xCO2-ic_0004 + 3xCO2-ic_0005: gs://vcm-ml-raw-flexible-retention/2024-11-11-C24-SHiELD-SOM-tuned-cdmbgwd/regridded-zarrs/gaussian_grid_45_by_90/3xCO2-ic_0005 + 4xCO2-ic_0001: gs://vcm-ml-raw-flexible-retention/2024-11-11-C24-SHiELD-SOM-tuned-cdmbgwd/regridded-zarrs/gaussian_grid_45_by_90/4xCO2-ic_0001 + 4xCO2-ic_0002: gs://vcm-ml-raw-flexible-retention/2024-11-11-C24-SHiELD-SOM-tuned-cdmbgwd/regridded-zarrs/gaussian_grid_45_by_90/4xCO2-ic_0002 + 4xCO2-ic_0003: gs://vcm-ml-raw-flexible-retention/2024-11-11-C24-SHiELD-SOM-tuned-cdmbgwd/regridded-zarrs/gaussian_grid_45_by_90/4xCO2-ic_0003 + 4xCO2-ic_0004: gs://vcm-ml-raw-flexible-retention/2024-11-11-C24-SHiELD-SOM-tuned-cdmbgwd/regridded-zarrs/gaussian_grid_45_by_90/4xCO2-ic_0004 + 4xCO2-ic_0005: gs://vcm-ml-raw-flexible-retention/2024-11-11-C24-SHiELD-SOM-tuned-cdmbgwd/regridded-zarrs/gaussian_grid_45_by_90/4xCO2-ic_0005 + abrupt-2xCO2: gs://vcm-ml-raw-flexible-retention/2024-11-11-C24-SHiELD-SOM-tuned-cdmbgwd/regridded-zarrs/gaussian_grid_45_by_90/abrupt-2xCO2 + abrupt-3xCO2: gs://vcm-ml-raw-flexible-retention/2024-11-11-C24-SHiELD-SOM-tuned-cdmbgwd/regridded-zarrs/gaussian_grid_45_by_90/abrupt-3xCO2 + abrupt-4xCO2: gs://vcm-ml-raw-flexible-retention/2024-11-11-C24-SHiELD-SOM-tuned-cdmbgwd/regridded-zarrs/gaussian_grid_45_by_90/abrupt-4xCO2 + increasing-CO2: gs://vcm-ml-raw-flexible-retention/2024-11-11-C24-SHiELD-SOM-tuned-cdmbgwd/regridded-zarrs/gaussian_grid_45_by_90/increasing-CO2 +data_output_directory: gs://vcm-ml-intermediate/2024-11-12-vertically-resolved-4deg-c24-shield-som-tuned-cdmbgwd-baseline-fme-dataset +stats: + output_directory: gs://vcm-ml-intermediate/2024-11-12-vertically-resolved-4deg-c24-shield-som-tuned-cdmbgwd-ensemble-fme-dataset-stats + beaker_dataset: 2024-11-12-vertically-resolved-4deg-fme-c24-shield-som-tuned-cdmbgwd-ensemble-dataset-stats + # These datasets already exclude the one-year divergence period for each + # ensemble member, so we can use data from all times when computing the + # stats. 
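+  # Note: this config mirrors the C24 SOM dataset above, but points at the
+  # 2024-11-11 simulations with re-tuned cdmbgwd settings (presumably the
+  # mountain-blocking / gravity-wave-drag coefficients, going by the run name).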
+ start_date: null + end_date: null + data_type: FV3GFS + exclude_runs: + # Exclude all runs here since we are not interested in these for ML training + - 1xCO2-ic_0001 + - 1xCO2-ic_0002 + - 1xCO2-ic_0003 + - 1xCO2-ic_0004 + - 1xCO2-ic_0005 + - 2xCO2-ic_0001 + - 2xCO2-ic_0002 + - 2xCO2-ic_0003 + - 2xCO2-ic_0004 + - 2xCO2-ic_0005 + - 3xCO2-ic_0001 + - 3xCO2-ic_0002 + - 3xCO2-ic_0003 + - 3xCO2-ic_0004 + - 3xCO2-ic_0005 + - 4xCO2-ic_0001 + - 4xCO2-ic_0002 + - 4xCO2-ic_0003 + - 4xCO2-ic_0004 + - 4xCO2-ic_0005 + - abrupt-2xCO2 + - abrupt-3xCO2 + - abrupt-4xCO2 + - increasing-co2 +dataset_computation: + reference_vertical_coordinate_file: gs://vcm-ml-raw-flexible-retention/2024-03-10-C96-SHiELD-FME-reference/vertical-coordinate-file/fv_core.res.nc + vertical_coarsening_indices: + - [0, 11] + - [11, 21] + - [21, 30] + - [30, 39] + - [39, 49] + - [49, 58] + - [58, 67] + - [67, 79] + renaming: + specific_humidity_at_two_meters: Q2m + air_temperature_at_two_meters: TMP2m + eastward_wind_at_ten_meters: UGRD10m + northward_wind_at_ten_meters: VGRD10m + variable_sources: + fluxes_2d.zarr: + - PRATEsfc + - LHTFLsfc + - SHTFLsfc + - DLWRFsfc + - DSWRFsfc + - DSWRFtoa + - ULWRFsfc + - ULWRFtoa + - USWRFsfc + - USWRFtoa + - precipitable_water_path + - GRAUPELsfc + - ICEsfc + - SNOWsfc + full_state.zarr: + - surface_temperature + - air_temperature + - specific_humidity + - cloud_water_mixing_ratio + - cloud_ice_mixing_ratio + - graupel_mixing_ratio + - rain_mixing_ratio + - snow_mixing_ratio + - northward_wind + - eastward_wind + - pressure_thickness_of_atmospheric_layer + - PRESsfc + - HGTsfc + - column_soil_moisture + - soil_moisture_0 + - soil_moisture_1 + - soil_moisture_2 + - soil_moisture_3 + - land_fraction + - ocean_fraction + - sea_ice_fraction + - specific_humidity_at_two_meters + - air_temperature_at_two_meters + - northward_wind_at_ten_meters + - eastward_wind_at_ten_meters + - RH200 + - RH500 + - RH850 + - TMP200 + - TMP500 + - TMP850 + - UGRD200 + - UGRD500 + - UGRD850 + - VGRD200 + - VGRD500 + - VGRD850 + - h50 + - h200 + - h500 + - h850 + ocean_forcing.zarr: + - prescribed_mixed_layer_depth + - prescribed_qflux + scalar.zarr: + - global_mean_co2 diff --git a/scripts/data_process/configs/shield-som-ensemble-c96-1deg-8layer.yaml b/scripts/data_process/configs/shield-som-ensemble-c96-1deg-8layer.yaml new file mode 100644 index 0000000..8bb3618 --- /dev/null +++ b/scripts/data_process/configs/shield-som-ensemble-c96-1deg-8layer.yaml @@ -0,0 +1,121 @@ +runs: + 1xCO2-ic_0001: gs://vcm-ml-raw-flexible-retention/2024-07-03-C96-SHiELD-SOM/regridded-zarrs/gaussian_grid_180_by_360/1xCO2-ic_0001 + 1xCO2-ic_0002: gs://vcm-ml-raw-flexible-retention/2024-07-03-C96-SHiELD-SOM/regridded-zarrs/gaussian_grid_180_by_360/1xCO2-ic_0002 + 1xCO2-ic_0003: gs://vcm-ml-raw-flexible-retention/2024-07-03-C96-SHiELD-SOM/regridded-zarrs/gaussian_grid_180_by_360/1xCO2-ic_0003 + 1xCO2-ic_0004: gs://vcm-ml-raw-flexible-retention/2024-07-03-C96-SHiELD-SOM/regridded-zarrs/gaussian_grid_180_by_360/1xCO2-ic_0004 + 1xCO2-ic_0005: gs://vcm-ml-raw-flexible-retention/2024-07-03-C96-SHiELD-SOM/regridded-zarrs/gaussian_grid_180_by_360/1xCO2-ic_0005 + 2xCO2-ic_0001: gs://vcm-ml-raw-flexible-retention/2024-07-03-C96-SHiELD-SOM/regridded-zarrs/gaussian_grid_180_by_360/2xCO2-ic_0001 + 2xCO2-ic_0002: gs://vcm-ml-raw-flexible-retention/2024-07-03-C96-SHiELD-SOM/regridded-zarrs/gaussian_grid_180_by_360/2xCO2-ic_0002 + 2xCO2-ic_0003: 
gs://vcm-ml-raw-flexible-retention/2024-07-03-C96-SHiELD-SOM/regridded-zarrs/gaussian_grid_180_by_360/2xCO2-ic_0003 + 2xCO2-ic_0004: gs://vcm-ml-raw-flexible-retention/2024-07-03-C96-SHiELD-SOM/regridded-zarrs/gaussian_grid_180_by_360/2xCO2-ic_0004 + 2xCO2-ic_0005: gs://vcm-ml-raw-flexible-retention/2024-07-03-C96-SHiELD-SOM/regridded-zarrs/gaussian_grid_180_by_360/2xCO2-ic_0005 + 3xCO2-ic_0001: gs://vcm-ml-raw-flexible-retention/2024-07-03-C96-SHiELD-SOM/regridded-zarrs/gaussian_grid_180_by_360/3xCO2-ic_0001 + 3xCO2-ic_0002: gs://vcm-ml-raw-flexible-retention/2024-07-03-C96-SHiELD-SOM/regridded-zarrs/gaussian_grid_180_by_360/3xCO2-ic_0002 + 3xCO2-ic_0003: gs://vcm-ml-raw-flexible-retention/2024-07-03-C96-SHiELD-SOM/regridded-zarrs/gaussian_grid_180_by_360/3xCO2-ic_0003 + 3xCO2-ic_0004: gs://vcm-ml-raw-flexible-retention/2024-07-03-C96-SHiELD-SOM/regridded-zarrs/gaussian_grid_180_by_360/3xCO2-ic_0004 + 3xCO2-ic_0005: gs://vcm-ml-raw-flexible-retention/2024-07-03-C96-SHiELD-SOM/regridded-zarrs/gaussian_grid_180_by_360/3xCO2-ic_0005 + 4xCO2-ic_0001: gs://vcm-ml-raw-flexible-retention/2024-07-03-C96-SHiELD-SOM/regridded-zarrs/gaussian_grid_180_by_360/4xCO2-ic_0001 + 4xCO2-ic_0002: gs://vcm-ml-raw-flexible-retention/2024-07-03-C96-SHiELD-SOM/regridded-zarrs/gaussian_grid_180_by_360/4xCO2-ic_0002 + 4xCO2-ic_0003: gs://vcm-ml-raw-flexible-retention/2024-07-03-C96-SHiELD-SOM/regridded-zarrs/gaussian_grid_180_by_360/4xCO2-ic_0003 + 4xCO2-ic_0004: gs://vcm-ml-raw-flexible-retention/2024-07-03-C96-SHiELD-SOM/regridded-zarrs/gaussian_grid_180_by_360/4xCO2-ic_0004 + 4xCO2-ic_0005: gs://vcm-ml-raw-flexible-retention/2024-07-03-C96-SHiELD-SOM/regridded-zarrs/gaussian_grid_180_by_360/4xCO2-ic_0005 +data_output_directory: gs://vcm-ml-intermediate/2024-07-09-vertically-resolved-1deg-c96-shield-som-ensemble-fme-dataset +stats: + output_directory: gs://vcm-ml-intermediate/2024-07-09-vertically-resolved-1deg-c96-shield-som-ensemble-fme-dataset-stats + beaker_dataset: 2024-07-09-vertically-resolved-1deg-fme-c96-shield-som-ensemble-dataset-stats + # These datasets already exclude the one-year divergence period for each + # ensemble member, so we can use data from all times when computing the + # stats. 
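+  # The combined statistics produced here (centering.nc, scaling-full-field.nc,
+  # scaling-residual.nc and time-mean.nc; see upload_stats.py below) are the
+  # normalization coefficients used for ML training.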
+ start_date: null + end_date: null + data_type: FV3GFS + exclude_runs: + # In sample validation data + - 1xCO2-ic_0005 + - 2xCO2-ic_0005 + - 4xCO2-ic_0005 + # Out of sample equilibrium climate data + - 3xCO2-ic_0001 + - 3xCO2-ic_0002 + - 3xCO2-ic_0003 + - 3xCO2-ic_0004 + - 3xCO2-ic_0005 +dataset_computation: + reference_vertical_coordinate_file: gs://vcm-ml-raw-flexible-retention/2024-03-10-C96-SHiELD-FME-reference/vertical-coordinate-file/fv_core.res.nc + vertical_coarsening_indices: + - [0, 11] + - [11, 21] + - [21, 30] + - [30, 39] + - [39, 49] + - [49, 58] + - [58, 67] + - [67, 79] + renaming: + specific_humidity_at_two_meters: Q2m + air_temperature_at_two_meters: TMP2m + eastward_wind_at_ten_meters: UGRD10m + northward_wind_at_ten_meters: VGRD10m + variable_sources: + fluxes_2d.zarr: + - PRATEsfc + - LHTFLsfc + - SHTFLsfc + - DLWRFsfc + - DSWRFsfc + - DSWRFtoa + - ULWRFsfc + - ULWRFtoa + - USWRFsfc + - USWRFtoa + - precipitable_water_path + - GRAUPELsfc + - ICEsfc + - SNOWsfc + full_state.zarr: + - surface_temperature + - air_temperature + - specific_humidity + - cloud_water_mixing_ratio + - cloud_ice_mixing_ratio + - graupel_mixing_ratio + - rain_mixing_ratio + - snow_mixing_ratio + - northward_wind + - eastward_wind + - pressure_thickness_of_atmospheric_layer + - PRESsfc + - HGTsfc + - column_soil_moisture + - soil_moisture_0 + - soil_moisture_1 + - soil_moisture_2 + - soil_moisture_3 + - land_fraction + - ocean_fraction + - sea_ice_fraction + - specific_humidity_at_two_meters + - air_temperature_at_two_meters + - northward_wind_at_ten_meters + - eastward_wind_at_ten_meters + - RH200 + - RH500 + - RH850 + - TMP200 + - TMP500 + - TMP850 + - UGRD200 + - UGRD500 + - UGRD850 + - VGRD200 + - VGRD500 + - VGRD850 + - h50 + - h200 + - h500 + - h850 + ocean_forcing.zarr: + - prescribed_mixed_layer_depth + - prescribed_qflux + scalar.zarr: + - global_mean_co2 diff --git a/scripts/data_process/configs/shield-som-ensemble-c96-4deg-8layer.yaml b/scripts/data_process/configs/shield-som-ensemble-c96-4deg-8layer.yaml new file mode 100644 index 0000000..1df2bfe --- /dev/null +++ b/scripts/data_process/configs/shield-som-ensemble-c96-4deg-8layer.yaml @@ -0,0 +1,127 @@ +runs: + 1xCO2-ic_0001: gs://vcm-ml-raw-flexible-retention/2024-07-03-C96-SHiELD-SOM/regridded-zarrs/gaussian_grid_45_by_90/1xCO2-ic_0001 + 1xCO2-ic_0002: gs://vcm-ml-raw-flexible-retention/2024-07-03-C96-SHiELD-SOM/regridded-zarrs/gaussian_grid_45_by_90/1xCO2-ic_0002 + 1xCO2-ic_0003: gs://vcm-ml-raw-flexible-retention/2024-07-03-C96-SHiELD-SOM/regridded-zarrs/gaussian_grid_45_by_90/1xCO2-ic_0003 + 1xCO2-ic_0004: gs://vcm-ml-raw-flexible-retention/2024-07-03-C96-SHiELD-SOM/regridded-zarrs/gaussian_grid_45_by_90/1xCO2-ic_0004 + 1xCO2-ic_0005: gs://vcm-ml-raw-flexible-retention/2024-07-03-C96-SHiELD-SOM/regridded-zarrs/gaussian_grid_45_by_90/1xCO2-ic_0005 + 2xCO2-ic_0001: gs://vcm-ml-raw-flexible-retention/2024-07-03-C96-SHiELD-SOM/regridded-zarrs/gaussian_grid_45_by_90/2xCO2-ic_0001 + 2xCO2-ic_0002: gs://vcm-ml-raw-flexible-retention/2024-07-03-C96-SHiELD-SOM/regridded-zarrs/gaussian_grid_45_by_90/2xCO2-ic_0002 + 2xCO2-ic_0003: gs://vcm-ml-raw-flexible-retention/2024-07-03-C96-SHiELD-SOM/regridded-zarrs/gaussian_grid_45_by_90/2xCO2-ic_0003 + 2xCO2-ic_0004: gs://vcm-ml-raw-flexible-retention/2024-07-03-C96-SHiELD-SOM/regridded-zarrs/gaussian_grid_45_by_90/2xCO2-ic_0004 + 2xCO2-ic_0005: gs://vcm-ml-raw-flexible-retention/2024-07-03-C96-SHiELD-SOM/regridded-zarrs/gaussian_grid_45_by_90/2xCO2-ic_0005 + 3xCO2-ic_0001: 
gs://vcm-ml-raw-flexible-retention/2024-07-03-C96-SHiELD-SOM/regridded-zarrs/gaussian_grid_45_by_90/3xCO2-ic_0001 + 3xCO2-ic_0002: gs://vcm-ml-raw-flexible-retention/2024-07-03-C96-SHiELD-SOM/regridded-zarrs/gaussian_grid_45_by_90/3xCO2-ic_0002 + # For expediency we did not regrid these ensemble members to 4 degree + # resolution. Where needed, we regrid after computing a time and ensemble mean + # with the one degree data. + # 3xCO2-ic_0003: gs://vcm-ml-raw-flexible-retention/2024-07-03-C96-SHiELD-SOM/regridded-zarrs/gaussian_grid_45_by_90/3xCO2-ic_0003 + # 3xCO2-ic_0004: gs://vcm-ml-raw-flexible-retention/2024-07-03-C96-SHiELD-SOM/regridded-zarrs/gaussian_grid_45_by_90/3xCO2-ic_0004 + # 3xCO2-ic_0005: gs://vcm-ml-raw-flexible-retention/2024-07-03-C96-SHiELD-SOM/regridded-zarrs/gaussian_grid_45_by_90/3xCO2-ic_0005 + 4xCO2-ic_0001: gs://vcm-ml-raw-flexible-retention/2024-07-03-C96-SHiELD-SOM/regridded-zarrs/gaussian_grid_45_by_90/4xCO2-ic_0001 + 4xCO2-ic_0002: gs://vcm-ml-raw-flexible-retention/2024-07-03-C96-SHiELD-SOM/regridded-zarrs/gaussian_grid_45_by_90/4xCO2-ic_0002 + 4xCO2-ic_0003: gs://vcm-ml-raw-flexible-retention/2024-07-03-C96-SHiELD-SOM/regridded-zarrs/gaussian_grid_45_by_90/4xCO2-ic_0003 + 4xCO2-ic_0004: gs://vcm-ml-raw-flexible-retention/2024-07-03-C96-SHiELD-SOM/regridded-zarrs/gaussian_grid_45_by_90/4xCO2-ic_0004 + 4xCO2-ic_0005: gs://vcm-ml-raw-flexible-retention/2024-07-03-C96-SHiELD-SOM/regridded-zarrs/gaussian_grid_45_by_90/4xCO2-ic_0005 +data_output_directory: gs://vcm-ml-intermediate/2024-07-09-vertically-resolved-4deg-c96-shield-som-ensemble-fme-dataset +stats: + output_directory: gs://vcm-ml-intermediate/2024-07-09-vertically-resolved-4deg-c96-shield-som-ensemble-fme-dataset-stats + beaker_dataset: 2024-07-09-vertically-resolved-4deg-fme-c96-shield-som-ensemble-dataset-stats + # These datasets already exclude the one-year divergence period for each + # ensemble member, so we can use data from all times when computing the + # stats. + start_date: null + end_date: null + data_type: FV3GFS + exclude_runs: + # In sample validation data + - 1xCO2-ic_0005 + - 2xCO2-ic_0005 + - 4xCO2-ic_0005 + # Out of sample equilibrium climate data + - 3xCO2-ic_0001 + - 3xCO2-ic_0002 + # For expediency we did not regrid these ensemble members to 4 degree + # resolution. Where needed, we regrid after computing a time and ensemble + # mean with the one degree data. 
+ # - 3xCO2-ic_0003 + # - 3xCO2-ic_0004 + # - 3xCO2-ic_0005 +dataset_computation: + reference_vertical_coordinate_file: gs://vcm-ml-raw-flexible-retention/2024-03-10-C96-SHiELD-FME-reference/vertical-coordinate-file/fv_core.res.nc + vertical_coarsening_indices: + - [0, 11] + - [11, 21] + - [21, 30] + - [30, 39] + - [39, 49] + - [49, 58] + - [58, 67] + - [67, 79] + renaming: + specific_humidity_at_two_meters: Q2m + air_temperature_at_two_meters: TMP2m + eastward_wind_at_ten_meters: UGRD10m + northward_wind_at_ten_meters: VGRD10m + variable_sources: + fluxes_2d.zarr: + - PRATEsfc + - LHTFLsfc + - SHTFLsfc + - DLWRFsfc + - DSWRFsfc + - DSWRFtoa + - ULWRFsfc + - ULWRFtoa + - USWRFsfc + - USWRFtoa + - precipitable_water_path + - GRAUPELsfc + - ICEsfc + - SNOWsfc + full_state.zarr: + - surface_temperature + - air_temperature + - specific_humidity + - cloud_water_mixing_ratio + - cloud_ice_mixing_ratio + - graupel_mixing_ratio + - rain_mixing_ratio + - snow_mixing_ratio + - northward_wind + - eastward_wind + - pressure_thickness_of_atmospheric_layer + - PRESsfc + - HGTsfc + - column_soil_moisture + - soil_moisture_0 + - soil_moisture_1 + - soil_moisture_2 + - soil_moisture_3 + - land_fraction + - ocean_fraction + - sea_ice_fraction + - specific_humidity_at_two_meters + - air_temperature_at_two_meters + - northward_wind_at_ten_meters + - eastward_wind_at_ten_meters + - RH200 + - RH500 + - RH850 + - TMP200 + - TMP500 + - TMP850 + - UGRD200 + - UGRD500 + - UGRD850 + - VGRD200 + - VGRD500 + - VGRD850 + - h50 + - h200 + - h500 + - h850 + ocean_forcing.zarr: + - prescribed_mixed_layer_depth + - prescribed_qflux + scalar.zarr: + - global_mean_co2 diff --git a/scripts/data_process/configs/shield-som-increasing-co2-c96-1deg-8layer.yaml b/scripts/data_process/configs/shield-som-increasing-co2-c96-1deg-8layer.yaml new file mode 100644 index 0000000..843f433 --- /dev/null +++ b/scripts/data_process/configs/shield-som-increasing-co2-c96-1deg-8layer.yaml @@ -0,0 +1,90 @@ +runs: + increasing-CO2: gs://vcm-ml-raw-flexible-retention/2024-07-03-C96-SHiELD-SOM/regridded-zarrs/gaussian_grid_180_by_360/increasing-CO2 +data_output_directory: gs://vcm-ml-intermediate/2024-07-16-vertically-resolved-1deg-c96-shield-som-increasing-co2-fme-dataset +stats: + output_directory: gs://vcm-ml-intermediate/2024-07-16-vertically-resolved-1deg-c96-shield-som-increasing-co2-fme-dataset-stats + beaker_dataset: 2024-07-16-vertically-resolved-1deg-fme-c96-shield-som-increasing-co2-dataset-stats + start_date: null + end_date: null + data_type: FV3GFS + exclude_runs: + - increasing-CO2 +dataset_computation: + reference_vertical_coordinate_file: gs://vcm-ml-raw-flexible-retention/2024-03-10-C96-SHiELD-FME-reference/vertical-coordinate-file/fv_core.res.nc + vertical_coarsening_indices: + - [0, 11] + - [11, 21] + - [21, 30] + - [30, 39] + - [39, 49] + - [49, 58] + - [58, 67] + - [67, 79] + renaming: + specific_humidity_at_two_meters: Q2m + air_temperature_at_two_meters: TMP2m + eastward_wind_at_ten_meters: UGRD10m + northward_wind_at_ten_meters: VGRD10m + variable_sources: + fluxes_2d.zarr: + - PRATEsfc + - LHTFLsfc + - SHTFLsfc + - DLWRFsfc + - DSWRFsfc + - DSWRFtoa + - ULWRFsfc + - ULWRFtoa + - USWRFsfc + - USWRFtoa + - precipitable_water_path + - GRAUPELsfc + - ICEsfc + - SNOWsfc + full_state.zarr: + - surface_temperature + - air_temperature + - specific_humidity + - cloud_water_mixing_ratio + - cloud_ice_mixing_ratio + - graupel_mixing_ratio + - rain_mixing_ratio + - snow_mixing_ratio + - northward_wind + - eastward_wind + - 
pressure_thickness_of_atmospheric_layer + - PRESsfc + - HGTsfc + - column_soil_moisture + - soil_moisture_0 + - soil_moisture_1 + - soil_moisture_2 + - soil_moisture_3 + - land_fraction + - ocean_fraction + - sea_ice_fraction + - specific_humidity_at_two_meters + - air_temperature_at_two_meters + - northward_wind_at_ten_meters + - eastward_wind_at_ten_meters + - RH200 + - RH500 + - RH850 + - TMP200 + - TMP500 + - TMP850 + - UGRD200 + - UGRD500 + - UGRD850 + - VGRD200 + - VGRD500 + - VGRD850 + - h50 + - h200 + - h500 + - h850 + ocean_forcing.zarr: + - prescribed_mixed_layer_depth + - prescribed_qflux + scalar.zarr: + - global_mean_co2 diff --git a/scripts/data_process/configs/shield-som-increasing-co2-c96-4deg-8layer.yaml b/scripts/data_process/configs/shield-som-increasing-co2-c96-4deg-8layer.yaml new file mode 100644 index 0000000..d82256e --- /dev/null +++ b/scripts/data_process/configs/shield-som-increasing-co2-c96-4deg-8layer.yaml @@ -0,0 +1,90 @@ +runs: + increasing-CO2: gs://vcm-ml-raw-flexible-retention/2024-07-03-C96-SHiELD-SOM/regridded-zarrs/gaussian_grid_45_by_90/increasing-CO2 +data_output_directory: gs://vcm-ml-intermediate/2024-07-16-vertically-resolved-4deg-c96-shield-som-increasing-co2-fme-dataset +stats: + output_directory: gs://vcm-ml-intermediate/2024-07-16-vertically-resolved-4deg-c96-shield-som-increasing-co2-fme-dataset-stats + beaker_dataset: 2024-07-16-vertically-resolved-4deg-fme-c96-shield-som-increasing-co2-dataset-stats + start_date: null + end_date: null + data_type: FV3GFS + exclude_runs: + - increasing-CO2 +dataset_computation: + reference_vertical_coordinate_file: gs://vcm-ml-raw-flexible-retention/2024-03-10-C96-SHiELD-FME-reference/vertical-coordinate-file/fv_core.res.nc + vertical_coarsening_indices: + - [0, 11] + - [11, 21] + - [21, 30] + - [30, 39] + - [39, 49] + - [49, 58] + - [58, 67] + - [67, 79] + renaming: + specific_humidity_at_two_meters: Q2m + air_temperature_at_two_meters: TMP2m + eastward_wind_at_ten_meters: UGRD10m + northward_wind_at_ten_meters: VGRD10m + variable_sources: + fluxes_2d.zarr: + - PRATEsfc + - LHTFLsfc + - SHTFLsfc + - DLWRFsfc + - DSWRFsfc + - DSWRFtoa + - ULWRFsfc + - ULWRFtoa + - USWRFsfc + - USWRFtoa + - precipitable_water_path + - GRAUPELsfc + - ICEsfc + - SNOWsfc + full_state.zarr: + - surface_temperature + - air_temperature + - specific_humidity + - cloud_water_mixing_ratio + - cloud_ice_mixing_ratio + - graupel_mixing_ratio + - rain_mixing_ratio + - snow_mixing_ratio + - northward_wind + - eastward_wind + - pressure_thickness_of_atmospheric_layer + - PRESsfc + - HGTsfc + - column_soil_moisture + - soil_moisture_0 + - soil_moisture_1 + - soil_moisture_2 + - soil_moisture_3 + - land_fraction + - ocean_fraction + - sea_ice_fraction + - specific_humidity_at_two_meters + - air_temperature_at_two_meters + - northward_wind_at_ten_meters + - eastward_wind_at_ten_meters + - RH200 + - RH500 + - RH850 + - TMP200 + - TMP500 + - TMP850 + - UGRD200 + - UGRD500 + - UGRD850 + - VGRD200 + - VGRD500 + - VGRD850 + - h50 + - h200 + - h500 + - h850 + ocean_forcing.zarr: + - prescribed_mixed_layer_depth + - prescribed_qflux + scalar.zarr: + - global_mean_co2 diff --git a/scripts/data_process/configs/shield-som-radiation-multi-call-c96-1deg-8layer.yaml b/scripts/data_process/configs/shield-som-radiation-multi-call-c96-1deg-8layer.yaml new file mode 100644 index 0000000..19d1366 --- /dev/null +++ b/scripts/data_process/configs/shield-som-radiation-multi-call-c96-1deg-8layer.yaml @@ -0,0 +1,90 @@ +runs: + radiation-multi-call: 
gs://vcm-ml-raw-flexible-retention/2024-07-03-C96-SHiELD-SOM/regridded-zarrs/gaussian_grid_180_by_360/radiation-multi-call +data_output_directory: gs://vcm-ml-intermediate/2024-10-22-vertically-resolved-1deg-c96-shield-som-radiation-multi-call-fme-dataset +stats: + output_directory: gs://vcm-ml-intermediate/2024-10-22-vertically-resolved-1deg-c96-shield-som-radiation-multi-call-fme-dataset-stats + beaker_dataset: 2024-10-22-vertically-resolved-1deg-fme-c96-radiation-multi-call-co2-dataset-stats + start_date: null + end_date: null + data_type: FV3GFS + exclude_runs: + - radiation-multi-call +dataset_computation: + reference_vertical_coordinate_file: gs://vcm-ml-raw-flexible-retention/2024-03-10-C96-SHiELD-FME-reference/vertical-coordinate-file/fv_core.res.nc + vertical_coarsening_indices: + - [0, 11] + - [11, 21] + - [21, 30] + - [30, 39] + - [39, 49] + - [49, 58] + - [58, 67] + - [67, 79] + renaming: + specific_humidity_at_two_meters: Q2m + air_temperature_at_two_meters: TMP2m + eastward_wind_at_ten_meters: UGRD10m + northward_wind_at_ten_meters: VGRD10m + variable_sources: + fluxes_2d.zarr: + - PRATEsfc + - LHTFLsfc + - SHTFLsfc + - DLWRFsfc + - DSWRFsfc + - DSWRFtoa + - ULWRFsfc + - ULWRFtoa + - USWRFsfc + - USWRFtoa + - precipitable_water_path + - GRAUPELsfc + - ICEsfc + - SNOWsfc + full_state.zarr: + - surface_temperature + - air_temperature + - specific_humidity + - cloud_water_mixing_ratio + - cloud_ice_mixing_ratio + - graupel_mixing_ratio + - rain_mixing_ratio + - snow_mixing_ratio + - northward_wind + - eastward_wind + - pressure_thickness_of_atmospheric_layer + - PRESsfc + - HGTsfc + - column_soil_moisture + - soil_moisture_0 + - soil_moisture_1 + - soil_moisture_2 + - soil_moisture_3 + - land_fraction + - ocean_fraction + - sea_ice_fraction + - specific_humidity_at_two_meters + - air_temperature_at_two_meters + - northward_wind_at_ten_meters + - eastward_wind_at_ten_meters + - RH200 + - RH500 + - RH850 + - TMP200 + - TMP500 + - TMP850 + - UGRD200 + - UGRD500 + - UGRD850 + - VGRD200 + - VGRD500 + - VGRD850 + - h50 + - h200 + - h500 + - h850 + ocean_forcing.zarr: + - prescribed_mixed_layer_depth + - prescribed_qflux + scalar.zarr: + - global_mean_co2 diff --git a/scripts/data_process/configs/shield-som-spin-up-c96-1deg-8layer.yaml b/scripts/data_process/configs/shield-som-spin-up-c96-1deg-8layer.yaml new file mode 100644 index 0000000..b6b5c64 --- /dev/null +++ b/scripts/data_process/configs/shield-som-spin-up-c96-1deg-8layer.yaml @@ -0,0 +1,97 @@ +runs: + 1xCO2-spin-up-ic_0005: gs://vcm-ml-raw-flexible-retention/2024-07-03-C96-SHiELD-SOM/regridded-zarrs/gaussian_grid_180_by_360/1xCO2-spin-up-ic_0005 + 2xCO2-spin-up-ic_0005: gs://vcm-ml-raw-flexible-retention/2024-07-03-C96-SHiELD-SOM/regridded-zarrs/gaussian_grid_180_by_360/2xCO2-spin-up-ic_0005 + 3xCO2-spin-up-ic_0002: gs://vcm-ml-raw-flexible-retention/2024-07-03-C96-SHiELD-SOM/regridded-zarrs/gaussian_grid_180_by_360/3xCO2-spin-up-ic_0002 + 4xCO2-spin-up-ic_0005: gs://vcm-ml-raw-flexible-retention/2024-07-03-C96-SHiELD-SOM/regridded-zarrs/gaussian_grid_180_by_360/4xCO2-spin-up-ic_0005 +data_output_directory: gs://vcm-ml-intermediate/2024-08-15-vertically-resolved-1deg-c96-shield-som-ensemble-spin-up-fme-dataset +stats: + output_directory: gs://vcm-ml-intermediate/2024-08-15-vertically-resolved-1deg-c96-shield-som-ensemble-spin-up-fme-dataset-stats + beaker_dataset: 2024-08-15-vertically-resolved-1deg-fme-c96-shield-som-ensemble-spin-up-dataset-stats + start_date: null + end_date: null + data_type: FV3GFS + exclude_runs: + # In 
sample validation data + - 1xCO2-spin-up-ic_0005 + - 2xCO2-spin-up-ic_0005 + - 3xCO2-spin-up-ic_0002 + - 4xCO2-spin-up-ic_0005 +dataset_computation: + reference_vertical_coordinate_file: gs://vcm-ml-raw-flexible-retention/2024-03-10-C96-SHiELD-FME-reference/vertical-coordinate-file/fv_core.res.nc + vertical_coarsening_indices: + - [0, 11] + - [11, 21] + - [21, 30] + - [30, 39] + - [39, 49] + - [49, 58] + - [58, 67] + - [67, 79] + renaming: + specific_humidity_at_two_meters: Q2m + air_temperature_at_two_meters: TMP2m + eastward_wind_at_ten_meters: UGRD10m + northward_wind_at_ten_meters: VGRD10m + variable_sources: + fluxes_2d.zarr: + - PRATEsfc + - LHTFLsfc + - SHTFLsfc + - DLWRFsfc + - DSWRFsfc + - DSWRFtoa + - ULWRFsfc + - ULWRFtoa + - USWRFsfc + - USWRFtoa + - precipitable_water_path + - GRAUPELsfc + - ICEsfc + - SNOWsfc + full_state.zarr: + - surface_temperature + - air_temperature + - specific_humidity + - cloud_water_mixing_ratio + - cloud_ice_mixing_ratio + - graupel_mixing_ratio + - rain_mixing_ratio + - snow_mixing_ratio + - northward_wind + - eastward_wind + - pressure_thickness_of_atmospheric_layer + - PRESsfc + - HGTsfc + - column_soil_moisture + - soil_moisture_0 + - soil_moisture_1 + - soil_moisture_2 + - soil_moisture_3 + - land_fraction + - ocean_fraction + - sea_ice_fraction + - specific_humidity_at_two_meters + - air_temperature_at_two_meters + - northward_wind_at_ten_meters + - eastward_wind_at_ten_meters + - RH200 + - RH500 + - RH850 + - TMP200 + - TMP500 + - TMP850 + - UGRD200 + - UGRD500 + - UGRD850 + - VGRD200 + - VGRD500 + - VGRD850 + - h50 + - h200 + - h500 + - h850 + ocean_forcing.zarr: + - prescribed_mixed_layer_depth + - prescribed_qflux + scalar.zarr: + - global_mean_co2 diff --git a/scripts/data_process/convert_to_monthly_netcdf_fv3gfs.sh b/scripts/data_process/convert_to_monthly_netcdf_fv3gfs.sh new file mode 100755 index 0000000..638901a --- /dev/null +++ b/scripts/data_process/convert_to_monthly_netcdf_fv3gfs.sh @@ -0,0 +1,56 @@ +#!/bin/bash + +# This script launches N_IC jobs to convert all GCS zarr data to local monthly netCDFs + +while [[ "$#" -gt 0 ]] +do case $1 in + --input-url) BASE_INPUT_URL="$2" + shift;; + --n-ic) N_IC=$2 + shift;; + --output-dir) BASE_OUTPUT_DIR="$2" + shift;; + --start-date) START_DATE="$2" + shift;; + --end-date) END_DATE="$2" + shift;; + *) echo "Unknown parameter passed: $1" + exit 1;; +esac +shift +done + +if [[ -z "${BASE_INPUT_URL}" ]] +then + echo "Option --input-url missing" + exit 1; +elif [[ -z "${N_IC}" ]] +then + echo "Option --n-ic missing" + exit 1; +elif [[ -z "${BASE_OUTPUT_DIR}" ]] +then + echo "Option --output-dir missing" + exit 1; +elif [[ -z "${START_DATE}" ]] +then + echo "Option --start-date missing" + exit 1; +elif [[ -z "${END_DATE}" ]] +then + echo "Option --end-date missing" + exit 1; +fi + + + +for IC in $(seq 1 $(( N_IC ))); do + IC_STR=$(printf "%04d" ${IC}) + INPUT_URL=${BASE_INPUT_URL}/ic_${IC_STR}.zarr + OUTPUT_DIR=${BASE_OUTPUT_DIR}/ic_${IC_STR} + python convert_to_monthly_netcdf.py \ + $INPUT_URL \ + $OUTPUT_DIR \ + --start-date $START_DATE \ + --end-date $END_DATE & +done diff --git a/scripts/data_process/earth2grid.Dockerfile b/scripts/data_process/earth2grid.Dockerfile new file mode 100644 index 0000000..e8fd37c --- /dev/null +++ b/scripts/data_process/earth2grid.Dockerfile @@ -0,0 +1,11 @@ +FROM nvcr.io/nvidia/pytorch:23.08-py3 + +# Clone and install earth2grid if not already installed +RUN PACKAGE=earth2grid && \ + if ! 
pip show "$PACKAGE" &>/dev/null; then \ + git clone https://github.com/NVlabs/earth2grid.git && \ + cd earth2grid && \ + pip install --no-build-isolation . && \ + cd .. && \ + rm -rf earth2grid; \ + fi \ No newline at end of file diff --git a/scripts/data_process/generate_beaker_stats_dataset_fv3gfs.sh b/scripts/data_process/generate_beaker_stats_dataset_fv3gfs.sh new file mode 100755 index 0000000..a440b5d --- /dev/null +++ b/scripts/data_process/generate_beaker_stats_dataset_fv3gfs.sh @@ -0,0 +1,54 @@ +#!/bin/bash + +set -e + +while [[ "$#" -gt 0 ]] +do case $1 in + --input-url) INPUT_URL="$2" + shift;; + --start-date) START_DATE="$2" + shift;; + --end-date) END_DATE="$2" + shift;; + --name) DATASET_NAME="$2" + shift;; + --desc) DATASET_DESC="$2" + shift;; + --script-flags) SCRIPT_FLAGS="$2" + shift;; + *) echo "Unknown parameter passed: $1" + exit 1;; +esac +shift +done + +if [[ -z "${INPUT_URL}" ]] +then + echo "Option --input-url missing" + exit 1; +elif [[ -z "${START_DATE}" ]] +then + echo "Option --start-date missing" + exit 1; +elif [[ -z "${END_DATE}" ]] +then + echo "Option --end-date missing" + exit 1; +elif [[ -z "${DATASET_NAME}" ]] +then + echo "Option --dataset-name missing" + exit 1; +fi + +OUTPUT_DIR="/tmp/$(uuidgen)" + +python get_stats.py \ + $INPUT_URL \ + ${OUTPUT_DIR} \ + --start-date $START_DATE \ + --end-date $END_DATE ${SCRIPT_FLAGS} + +beaker dataset create ${OUTPUT_DIR} \ + --name ${DATASET_NAME} --desc "${DATASET_DESC}" + +rm -rf ${OUTPUT_DIR} diff --git a/scripts/data_process/generate_datasets_e3smv2.sh b/scripts/data_process/generate_datasets_e3smv2.sh new file mode 100755 index 0000000..056204f --- /dev/null +++ b/scripts/data_process/generate_datasets_e3smv2.sh @@ -0,0 +1,89 @@ +#!/bin/bash -l + +#SBATCH -A m4331 +#SBATCH -q regular +#SBATCH -C cpu +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=1 +#SBATCH -t 01:00:00 +#SBATCH --output=joblogs/%x-%j.out + +while [[ "$#" -gt 0 ]] +do case $1 in + -i|--input-dir) INPUT_DIR="$2" + shift;; + -c|--config) CONFIG="$2" + shift;; + -z|--zarr) ZARR="$2" + shift;; + -o|--output-dir) OUTPUT_DIR="$2" + shift;; + *) echo "Unknown parameter passed: $1" + exit 1;; +esac +shift +done + +if [[ -z "${INPUT_DIR}" ]] +then + echo "Option -i, --input-dir missing" + exit 1; +elif [[ -z "${CONFIG}" ]] +then + echo "Option -c, --config missing" + exit 1; +elif [[ -z "${ZARR}" ]] +then + echo "Option -z, --zarr missing" + exit 1; +elif [[ -z "${OUTPUT_DIR}" ]] +then + echo "Option -o, --output-dir missing" + exit 1; +fi + +# output dir should be somewhere on $SCRATCH, even if the intention is to send +# the data to CFS or some external data store +mkdir -p $OUTPUT_DIR + +# stripe_small is recommended for files of size 1-10 GB on Perlmutter's Lustre +# scratch filesystem and stripes across 8 OSTs +# see https://docs.nersc.gov/performance/io/lustre/#nersc-file-striping-recommendations +stripe_small $OUTPUT_DIR + +# NOTE: assumes you've already created the fv3net conda env. 
See +# https://github.com/ai2cm/fv3net/blob/8ed295cf0b8ca49e24ae5d6dd00f57e8b30169ac/Makefile#L310 +source activate fv3net + +set -xe + +# create the zarr from E3SMv2 .nc files +python -u compute_dataset_e3smv2.py --n-workers=16 --config=${CONFIG} \ + -i ${INPUT_DIR} -o ${ZARR} + +# Train on first year (intended for training) +python -u convert_to_monthly_netcdf.py \ + ${ZARR} \ + ${OUTPUT_DIR}/traindata \ + --start-date 1970-01-01 \ + --end-date 1970-12-31 \ + --nc-format NETCDF4 + +# Validation on next 6 months +python -u convert_to_monthly_netcdf.py \ + ${ZARR} \ + ${OUTPUT_DIR}/validdata \ + --start-date 1971-01-01 \ + --end-date 1971-05-31 \ + --nc-format NETCDF4 + +# Final 6 months for preditiondata reference +python -u convert_to_monthly_netcdf.py \ + ${ZARR} \ + ${OUTPUT_DIR}/predictiondata \ + --start-date 1971-06-01 \ + --end-date 1971-12-31 \ + --nc-format NETCDF4 + +# compute all stats on training data +python -u get_stats.py ${CONFIG} 0 diff --git a/scripts/data_process/get_stats.py b/scripts/data_process/get_stats.py index 2a374df..448e787 100644 --- a/scripts/data_process/get_stats.py +++ b/scripts/data_process/get_stats.py @@ -38,12 +38,13 @@ "FV3GFS": ["time", "grid_xt", "grid_yt"], "E3SMV2": ["time", "lat", "lon"], "ERA5": ["time", "latitude", "longitude"], + "CM4": ["time", "lat", "lon"], } def add_history_attrs(ds, input_zarr, start_date, end_date, n_samples): ds.attrs["history"] = ( - "Created by ace/fv3gfs_data_process/get_stats.py. INPUT_ZARR:" + "Created by full-model/fv3gfs_data_process/get_stats.py. INPUT_ZARR:" f" {input_zarr}, START_DATE: {start_date}, END_DATE: {end_date}." ) ds.attrs["input_samples"] = n_samples @@ -64,10 +65,11 @@ def copy(source: str, destination: str): @dataclasses.dataclass class StatsConfig: output_directory: str - data_type: Literal["FV3GFS", "E3SMV2", "ERA5"] + data_type: Literal["FV3GFS", "E3SMV2", "ERA5", "CM4"] exclude_runs: List[str] = dataclasses.field(default_factory=list) start_date: Optional[str] = None end_date: Optional[str] = None + beaker_dataset: Optional[str] = None @dataclasses.dataclass diff --git a/scripts/data_process/test_config.py b/scripts/data_process/test_config.py new file mode 100644 index 0000000..aa24d30 --- /dev/null +++ b/scripts/data_process/test_config.py @@ -0,0 +1,27 @@ +import os + +import dacite +import pytest +import yaml +from combine_stats import Config as CombineStatsConfig +from get_stats import Config as GetStatsConfig +from upload_stats import Config as UploadStatsConfig + +DIRNAME = os.path.abspath(os.path.dirname(__file__)) +# list files in DIRNAME/config +CONFIG_YAMLS = [ + os.path.join(DIRNAME + "/configs", f) + for f in os.listdir(DIRNAME + "/configs") + if f.endswith(".yaml") +] + + +@pytest.mark.parametrize( + "filename", + CONFIG_YAMLS, +) +@pytest.mark.parametrize("cls", [GetStatsConfig, UploadStatsConfig, CombineStatsConfig]) +def test_get_stats_valid(filename, cls): + with open(filename, "r") as f: + config_data = yaml.load(f, Loader=yaml.CLoader) + dacite.from_dict(data_class=cls, data=config_data) diff --git a/scripts/data_process/upload_stats.py b/scripts/data_process/upload_stats.py new file mode 100644 index 0000000..863ac91 --- /dev/null +++ b/scripts/data_process/upload_stats.py @@ -0,0 +1,89 @@ +import dataclasses +import shutil +import tempfile +from typing import Dict, List, Optional + +import click +import dacite +import fsspec +import yaml + + +def copy(source: str, destination: str): + """Copy between any two 'filesystems'. Do not use for large files. 
+ + Args: + source: Path to source file/object. + destination: Path to destination. + """ + with fsspec.open(source) as f_source: + with fsspec.open(destination, "wb") as f_destination: + shutil.copyfileobj(f_source, f_destination) + + +@dataclasses.dataclass +class StatsConfig: + output_directory: str + beaker_dataset: str + exclude_runs: List[str] = dataclasses.field(default_factory=list) + start_date: Optional[str] = None + end_date: Optional[str] = None + + +@dataclasses.dataclass +class Config: + runs: Dict[str, str] + data_output_directory: str + stats: StatsConfig + + +@click.command() +@click.argument("config_yaml", type=str) +def main(config_yaml: str): + """ + Combine statistics for the data processing pipeline. + + Arguments: + config_yaml -- Path to the configuration file for the data processing pipeline. + """ + # imported here so we don't need to install beaker for the tests + from beaker import Beaker + + with open(config_yaml, "r") as f: + config_data = yaml.load(f, Loader=yaml.CLoader) + config = dacite.from_dict(data_class=Config, data=config_data) + + stats_combined_dir = config.stats.output_directory + "/combined/" + beaker = Beaker.from_env() + with tempfile.TemporaryDirectory() as tmpdir: + for filename in ( + "centering.nc", + "scaling-full-field.nc", + "scaling-residual.nc", + "time-mean.nc", + ): + copy(stats_combined_dir + filename, tmpdir + "/" + filename) + runs = [run for run in config.runs if run not in config.stats.exclude_runs] + run_names = ", ".join(runs) + if config.stats.start_date is None: + start = "start of run" + else: + start = config.stats.start_date + if config.stats.end_date is None: + end = "end of run" + else: + end = config.stats.end_date + beaker.dataset.create( + config.stats.beaker_dataset, + tmpdir, + workspace="ai2/ace", + description=( + "Coefficients for normalization for data " + f"{config.data_output_directory} runs {run_names}. " + f"Computed from {start} to {end}." 
+ ), + ) + + +if __name__ == "__main__": + main() diff --git a/scripts/era5/netcdf_to_zarr/netcdf_to_zarr_pipeline.py b/scripts/era5/netcdf_to_zarr/netcdf_to_zarr_pipeline.py index 0826b98..33fa055 100644 --- a/scripts/era5/netcdf_to_zarr/netcdf_to_zarr_pipeline.py +++ b/scripts/era5/netcdf_to_zarr/netcdf_to_zarr_pipeline.py @@ -223,7 +223,7 @@ def create_record( def create_record_local( - item: Tuple[pd.Timestamp, pd.Timestamp, str, str, str] + item: Tuple[pd.Timestamp, pd.Timestamp, str, str, str], ) -> Generator[Tuple[xbeam.Key, xr.Dataset], None, None]: start_time, end_time, variable, category, path = item with tempfile.TemporaryDirectory() as tmpdir: diff --git a/scripts/era5/pipeline/xr-beam-pipeline.py b/scripts/era5/pipeline/xr-beam-pipeline.py index bfdd35f..41fbf12 100644 --- a/scripts/era5/pipeline/xr-beam-pipeline.py +++ b/scripts/era5/pipeline/xr-beam-pipeline.py @@ -2,6 +2,7 @@ import datetime import logging import os +from typing import Sequence import apache_beam as beam import metview @@ -61,8 +62,12 @@ def grid_attribute_fix(ds, names_to_fix, reference_name): URL_GOOGLE_LATLON = ( "gs://gcp-public-data-arco-era5/ar/full_37-1h-0p25deg-chunk-1.zarr-v3" ) +# following dataset was manually generated in https://github.com/ai2cm/explore/blob/master/oliwm/2024-07-31-generate-ERA5-co2.ipynb # noqa: E501 +URL_CO2 = "gs://vcm-ml-raw-flexible-retention/2024-11-11-co2-annual-mean-for-era5.zarr" -URL_NCAR_ERA5 = "gs://vcm-ml-intermediate/2024-05-17-era5-025deg-2D-variables-from-NCAR-as-zarr" # noqa: E501 +URL_NCAR_ERA5 = ( + "gs://vcm-ml-intermediate/2024-05-17-era5-025deg-2D-variables-from-NCAR-as-zarr" # noqa: E501 +) URL_INVARIANT = f"{URL_NCAR_ERA5}/e5.oper.invariant.zarr" URL_SURFACE_ANALYSIS_LATLON = f"{URL_NCAR_ERA5}/e5.oper.an.sfc.zarr" URL_MEAN_FLUX = f"{URL_NCAR_ERA5}/e5.oper.fc.sfc.meanflux.zarr" @@ -198,13 +203,6 @@ def grid_attribute_fix(ds, names_to_fix, reference_name): RENAME_Z_PRES = { f"geopotential_{p}": f"h{p}" for p in OUTPUT_PRESSURE_LEVELS_GEOPOTENTIAL } -RENAME_ETC = { - "skt": "surface_temperature", - "t2m": "TMP2m", - "u10": "UGRD10m", - "v10": "VGRD10m", - "d2m": "DPT2m", -} RENAME_PRESSURE_LEVEL = { **RENAME_Q_PRES, **RENAME_T_PRES, @@ -214,28 +212,26 @@ def grid_attribute_fix(ds, names_to_fix, reference_name): } -def set_nlayer_globals(output_layer_indices): - # set global vars that depend on the output layer indices - global OUTPUT_LAYER_INDICES - OUTPUT_LAYER_INDICES = output_layer_indices - assert OUTPUT_LAYER_INDICES[-1] == N_INPUT_LAYERS - - global N_OUTPUT_LAYERS - N_OUTPUT_LAYERS = len(OUTPUT_LAYER_INDICES) - 1 - - RENAME_Q = {f"q_{i}": f"specific_total_water_{i}" for i in range(N_OUTPUT_LAYERS)} - RENAME_T = {f"t_{i}": f"air_temperature_{i}" for i in range(N_OUTPUT_LAYERS)} - RENAME_U = {f"u_{i}": f"eastward_wind_{i}" for i in range(N_OUTPUT_LAYERS)} - RENAME_V = {f"v_{i}": f"northward_wind_{i}" for i in range(N_OUTPUT_LAYERS)} - - global RENAME_NATIVE - RENAME_NATIVE = { - **RENAME_Q, - **RENAME_T, - **RENAME_U, - **RENAME_V, - **RENAME_ETC, +def _get_native_rename_dict(n_output_layers): + rename_q = {f"q_{i}": f"specific_total_water_{i}" for i in range(n_output_layers)} + rename_t = {f"t_{i}": f"air_temperature_{i}" for i in range(n_output_layers)} + rename_u = {f"u_{i}": f"eastward_wind_{i}" for i in range(n_output_layers)} + rename_v = {f"v_{i}": f"northward_wind_{i}" for i in range(n_output_layers)} + rename_etc = { + "skt": "surface_temperature", + "t2m": "TMP2m", + "u10": "UGRD10m", + "v10": "VGRD10m", + "d2m": "DPT2m", + } + rename_native = 
{ + **rename_q, + **rename_t, + **rename_u, + **rename_v, + **rename_etc, } + return rename_native def _open_zarr(key, sel_indices): @@ -290,6 +286,17 @@ def open_quarter_degree_datasets(indices) -> xr.Dataset: return sfc, invariant +def open_co2_dataset(start_time, end_time) -> xr.Dataset: + co2 = xr.open_zarr(URL_CO2, chunks=None) + ds_start = pd.Timestamp(co2.time.values[0]) + ds_stop = pd.Timestamp(co2.time.values[-1]) + assert start_time >= ds_start, f"CO2 dataset time start out of bounds" + assert end_time <= ds_stop, f"CO2 dataset time stop out of bounds" + co2 = co2.sel(time=slice(start_time, end_time)) + co2 = co2.load() + return co2 + + def _to_dataset(fs: metview.Fieldset) -> xr.Dataset: return fs.to_dataset().load() @@ -375,7 +382,12 @@ def process_pressure_level_data(key, ds, output_grid=DEFAULT_OUTPUT_GRID): return new_key, output -def _process_native_data(ds: xr.Dataset, output_grid: str) -> xr.Dataset: +def _process_native_data( + ds: xr.Dataset, output_grid: str, output_layer_indices: Sequence[int] +) -> xr.Dataset: + n_output_layers = len(output_layer_indices) - 1 + rename_dict = _get_native_rename_dict(n_output_layers) + xr.set_options(keep_attrs=True) # singleton time dimension interferes with metview ds = ds.squeeze() @@ -405,15 +417,15 @@ def _process_native_data(ds: xr.Dataset, output_grid: str) -> xr.Dataset: thicknesses = _to_dataarray(thicknesses_fs, "pres") for short_name in ["q", "t", "u", "v"]: variable = _to_dataarray(fieldset_gg.select(shortName=short_name), short_name) - for output_index in range(N_OUTPUT_LAYERS): # type: ignore[name-defined] + for output_index in range(n_output_layers): logging.info( f"Computing vertical integral of {short_name} " f"for output layer {output_index}." ) fine_levels = slice( - OUTPUT_LAYER_INDICES[output_index], # type: ignore[name-defined] - OUTPUT_LAYER_INDICES[output_index + 1], # type: ignore[name-defined] + output_layer_indices[output_index], + output_layer_indices[output_index + 1], ) coarse_level_thicknesses = thicknesses.isel(hybrid=fine_levels) total_thickness = coarse_level_thicknesses.sum("hybrid") @@ -455,7 +467,7 @@ def _process_native_data(ds: xr.Dataset, output_grid: str) -> xr.Dataset: output = _adjust_latlon(output) - output = output.rename(RENAME_NATIVE) # type: ignore[name-defined] + output = output.rename(rename_dict) for name, attrs in DESIRED_ATTRS.items(): if name in output: output[name] = output[name].assign_attrs(**attrs) @@ -481,8 +493,13 @@ def _process_native_data(ds: xr.Dataset, output_grid: str) -> xr.Dataset: return output -def process_native_data(key, ds, output_grid=DEFAULT_OUTPUT_GRID): - output = _process_native_data(ds, output_grid) +def process_native_data( + key, + ds, + output_grid=DEFAULT_OUTPUT_GRID, + output_layer_indices=DEFAULT_OUTPUT_LAYER_INDICES, +): + output = _process_native_data(ds, output_grid, output_layer_indices) new_key = key.replace( offsets={"time": key.offsets["time"], "latitude": 0, "longitude": 0}, vars=frozenset(output.keys()), @@ -658,7 +675,9 @@ def process_quarter_degree_data_sfc_an( return new_key, output_ds -def _get_vertical_coordinate(ds: xr.Dataset, name: str) -> xr.Dataset: +def _get_vertical_coordinate( + ds: xr.Dataset, name: str, output_layer_indices: Sequence[int] +) -> xr.Dataset: """Get the ak/bk vertical coordinate on coarse layer interfaces. 
Assuming that ds[name] is a 3D variable which includes @@ -667,8 +686,8 @@ def _get_vertical_coordinate(ds: xr.Dataset, name: str) -> xr.Dataset: hybrid_sigma_values = ds[name].attrs["GRIB_pv"] ak = hybrid_sigma_values[: N_INPUT_LAYERS + 1] bk = hybrid_sigma_values[N_INPUT_LAYERS + 1 :] - ak_coarse = [ak[i] for i in OUTPUT_LAYER_INDICES] # type: ignore[name-defined] - bk_coarse = [bk[i] for i in OUTPUT_LAYER_INDICES] # type: ignore[name-defined] + ak_coarse = [ak[i] for i in output_layer_indices] + bk_coarse = [bk[i] for i in output_layer_indices] ak_coarse_ds = xr.Dataset({f"ak_{i}": value for i, value in enumerate(ak_coarse)}) bk_coarse_ds = xr.Dataset({f"bk_{i}": value for i, value in enumerate(bk_coarse)}) for name in ak_coarse_ds: @@ -684,10 +703,12 @@ def _make_template( ds_quarter_degree_sfc, ds_quarter_degree_invariant, ds_google_latlon, + ds_co2, ds_akbk, output_chunks, reuse_template, output_grid, + output_layer_indices, ): """Here we (mostly) lazily process the data to make a reference zarr store for the output. This function mirrors what the pipeline does.""" @@ -711,7 +732,9 @@ def _make_template( ds_sfc_an_regridded = _process_quarter_degree_data_sfc_an( ds_quarter_degree_sfc.isel(time=0), ds_invariant_regridded, output_grid ) - ds_native_regridded = _process_native_data(ds_native.isel(time=0), output_grid) + ds_native_regridded = _process_native_data( + ds_native.isel(time=0), output_grid, output_layer_indices + ) ds_google_latlon_regridded = _process_pressure_level_data( ds_google_latlon.isel(time=0), output_grid ) @@ -743,6 +766,7 @@ def _make_template( # land fraction and temporally variable sea ice fraction # this will get written eagerly since it is not chunked template = template.update(inv_fields) + template = template.update(ds_co2) return template, inv_fields @@ -804,9 +828,6 @@ def main(): parser = _get_parser() args, pipeline_args = parser.parse_known_args() - # set globals that depend on output layer indices - set_nlayer_globals(args.output_layer_indices) - # desired start/end of output dataset, inclusive start_time = datetime.datetime.strptime(args.start_time, "%Y-%m-%dT%H:%M:%S") end_time = datetime.datetime.strptime(args.end_time, "%Y-%m-%dT%H:%M:%S") @@ -855,8 +876,9 @@ def main(): sel_indices ) ds_google_latlon = open_google_latlon_dataset(sel_indices) + ds_co2 = open_co2_dataset(start_time, end_time) logging.info("Getting vertical coordinate") - ds_akbk = _get_vertical_coordinate(ds_native, "t") + ds_akbk = _get_vertical_coordinate(ds_native, "t", args.output_layer_indices) logging.info("Generating template") template, ds_pt25deg_inv_regridded = _make_template( @@ -865,10 +887,12 @@ def main(): ds_quarter_degree_sfc, ds_quarter_degree_inv, ds_google_latlon, + ds_co2, ds_akbk, output_chunks, args.reuse_template, args.output_grid, + args.output_layer_indices, ) logging.info("Template finished generating. 
@@ -914,7 +938,11 @@ def main():
             p
             | "native_DatasetToChunks"
             >> xbeam.DatasetToChunks(ds_native, chunks={"time": 1})
-            | beam.MapTuple(process_native_data, output_grid=args.output_grid)
+            | beam.MapTuple(
+                process_native_data,
+                output_grid=args.output_grid,
+                output_layer_indices=args.output_layer_indices,
+            )
             | "native_ConsolidateChunks" >> xbeam.ConsolidateChunks(output_chunks)
             | "native_to_zarr"
             >> xbeam.ChunksToZarr(args.output_path, template, output_chunks)
diff --git a/scripts/manual_backwards_compatibility/.gitignore b/scripts/manual_backwards_compatibility/.gitignore
new file mode 100644
index 0000000..1cf850e
--- /dev/null
+++ b/scripts/manual_backwards_compatibility/.gitignore
@@ -0,0 +1 @@
+test_inference_ace2_era5
\ No newline at end of file
diff --git a/scripts/manual_backwards_compatibility/ace2-era5.sh b/scripts/manual_backwards_compatibility/ace2-era5.sh
new file mode 100755
index 0000000..4768fc9
--- /dev/null
+++ b/scripts/manual_backwards_compatibility/ace2-era5.sh
@@ -0,0 +1,33 @@
+#!/bin/bash
+
+set -e
+
+# This script can be used to ensure that your currently installed version of the fme
+# package can do inference with the published ACE2-ERA5 model.
+
+# download necessary data
+mkdir -p test_inference_ace2_era5
+cd test_inference_ace2_era5
+mkdir -p initial_conditions
+mkdir -p forcing_data
+wget https://huggingface.co/allenai/ACE2-ERA5/resolve/main/ace2_era5_ckpt.tar?download=true -O ace2_era5_ckpt.tar
+wget https://huggingface.co/allenai/ACE2-ERA5/resolve/main/inference_config.yaml?download=true -O inference_config.yaml
+wget https://huggingface.co/allenai/ACE2-ERA5/resolve/main/initial_conditions/ic_2020.nc?download=true -O initial_conditions/ic_2020.nc
+wget https://huggingface.co/allenai/ACE2-ERA5/resolve/main/forcing_data/forcing_2020.nc?download=true -O forcing_data/forcing_2020.nc
+
+# update config to use relative paths and do a short run
+yq e '.n_forward_steps = 50' -i inference_config.yaml
+yq e '.forward_steps_in_memory = 5' -i inference_config.yaml
+yq e '.checkpoint_path = "ace2_era5_ckpt.tar"' -i inference_config.yaml
+yq e '.initial_condition.path = "initial_conditions/ic_2020.nc"' -i inference_config.yaml
+yq e '.forcing_loader.dataset.data_path = "forcing_data/"' -i inference_config.yaml
+
+# run on CPU or CUDA if the latter is available
+yq e '.experiment_dir = "output_cpu"' -i inference_config.yaml
+python -m fme.ace.inference inference_config.yaml
+
+# run on MPS. NOTE: this requires torch==2.5 otherwise there are complaints about some of the
+# features used by the SFNO architecture.
+yq e '.experiment_dir = "output_mps"' -i inference_config.yaml
+export FME_USE_MPS=1
+python -m fme.ace.inference inference_config.yaml
diff --git a/scripts/monthly_data/test_write_monthly_data.py b/scripts/monthly_data/test_write_monthly_data.py
index a27a638..3c2e44c 100644
--- a/scripts/monthly_data/test_write_monthly_data.py
+++ b/scripts/monthly_data/test_write_monthly_data.py
@@ -1,13 +1,15 @@
 import pathlib
 from typing import List

+import pytest
 import xarray as xr
 from write_monthly_data import Config, run

-from fme.core.data_loading.config import DataLoaderConfig, XarrayDataConfig
+from fme.ace.data_loading.config import DataLoaderConfig
+from fme.ace.testing import DimSize, DimSizes
+from fme.ace.testing.fv3gfs_data import save_nd_netcdf
+from fme.core.dataset.config import XarrayDataConfig
 from fme.core.logging_utils import LoggingConfig
-from fme.core.testing import DimSizes
-from fme.core.testing.fv3gfs_data import save_2d_netcdf


 def write_ensemble_dataset(
@@ -18,7 +20,7 @@
     for i in range(n_members):
         ensemble_dir = path / f"ic_{i:04d}"
         ensemble_dir.mkdir(exist_ok=True)
-        save_2d_netcdf(
+        save_nd_netcdf(
             ensemble_dir / "data.nc",
             dim_sizes,
             names,
 )


-def test_write_monthly_data(tmp_path: pathlib.Path):
+def test_write_monthly_data(very_fast_only: bool, tmp_path: pathlib.Path):
+    if very_fast_only:
+        pytest.skip("Skipping non-fast tests")
     all_names = ["a", "b"]
+    horizontal = [DimSize("grid_yt", 8), DimSize("grid_xt", 4)]
     dim_sizes = DimSizes(
         n_time=4 * 60,
-        n_lat=4,
-        n_lon=8,
+        horizontal=horizontal,
         nz_interface=2,
     )
     n_members = 3
@@ -45,7 +49,7 @@
         data_loader=DataLoaderConfig(
             dataset=dataset,
             batch_size=1,
-            num_data_workers=1,
+            num_data_workers=0,
         ),
         logging=LoggingConfig(
             log_to_screen=True,
diff --git a/scripts/monthly_data/write_monthly_data.py b/scripts/monthly_data/write_monthly_data.py
index 39057d7..096415a 100644
--- a/scripts/monthly_data/write_monthly_data.py
+++ b/scripts/monthly_data/write_monthly_data.py
@@ -1,9 +1,8 @@
 import argparse
 import dataclasses
-import datetime
 import logging
 import os
-from typing import List
+from typing import List, Sequence, Tuple

 import dacite
 import torch.utils.data
@@ -11,24 +10,42 @@
 import yaml

 import fme.core.logging_utils as logging_utils
+from fme.ace.data_loading.batch_data import BatchData, default_collate
+from fme.ace.data_loading.config import DataLoaderConfig
 from fme.ace.inference.data_writer.monthly import (
     MonthlyDataWriter,
     months_for_timesteps,
 )
-from fme.ace.inference.derived_variables import compute_derived_quantities
-from fme.core.data_loading.config import DataLoaderConfig
-from fme.core.data_loading.data_typing import SigmaCoordinates
-from fme.core.data_loading.getters import get_datasets
-from fme.core.data_loading.requirements import DataRequirements
-from fme.core.data_loading.utils import BatchData
+from fme.ace.stepper import AtmosphericDeriveFn
+from fme.core.dataset.getters import get_datasets
+from fme.core.dataset.requirements import DataRequirements
+from fme.core.dataset.xarray import DatasetProperties
 from fme.core.device import using_gpu
 from fme.core.distributed import Distributed
 from fme.core.logging_utils import LoggingConfig
+from fme.core.typing_ import TensorMapping
+
+
+@dataclasses.dataclass
+class CollateFn:
+    horizontal_dims: List[str]
+
+    def __call__(
+        self, samples: Sequence[Tuple[TensorMapping, xr.DataArray]]
+    ) -> "BatchData":
+        sample_data, sample_time = zip(*samples)
+        batch_data = default_collate(sample_data)
+        batch_time = xr.concat(sample_time, dim="sample")
+        return BatchData(
+            data=batch_data,
+            time=batch_time,
+            horizontal_dims=self.horizontal_dims,
+        )


 def get_data_loaders(
     config: DataLoaderConfig, requirements: DataRequirements
-) -> List[torch.utils.data.DataLoader]:
+) -> Tuple[List[torch.utils.data.DataLoader], DatasetProperties]:
     dist = Distributed.get_instance()
     if dist.world_size > 1:
         raise RuntimeError(
@@ -36,7 +53,7 @@
             "supported in distributed mode."
         )

-    datasets = get_datasets(config.dataset, requirements)
+    datasets, properties = get_datasets(config.dataset, requirements)
     data_loaders = []

     for dataset in datasets:
@@ -48,10 +65,12 @@
             sampler=None,
             drop_last=True,
             pin_memory=using_gpu(),
-            collate_fn=BatchData.from_sample_tuples,
+            collate_fn=CollateFn(
+                horizontal_dims=list(properties.horizontal_coordinates.dims),
+            ),
         )
         data_loaders.append(dataloader)
-    return data_loaders
+    return data_loaders, properties


 def get_timesteps(data_loaders: List[torch.utils.data.DataLoader]) -> int:
@@ -66,7 +85,7 @@ class Config:
     """
     Configuration for applying the MonthlyDataWriter to a dataset.

-    Attributes:
+    Parameters:
         experiment_dir: Directory to save results to.
         dataset: Configuration for the dataset to load.
         num_data_workers: Number of parallel workers to use for data loading.
@@ -88,7 +107,7 @@ def __post_init__(self):
             raise ValueError("Batch size must be 1 to write dataset using writer.")

     def get_data(self) -> "Data":
-        data_loaders = get_data_loaders(
+        data_loaders, properties = get_data_loaders(
             config=self.data_loader,
             requirements=DataRequirements(
                 names=self.variable_names,
@@ -98,8 +117,7 @@
         n_timesteps = get_timesteps(data_loaders=data_loaders)
         return Data(
             loaders=data_loaders,
-            sigma_coordinates=data_loaders[0].dataset.sigma_coordinates,
-            timestep=data_loaders[0].dataset.timestep,
+            properties=properties,
             n_timesteps=n_timesteps,
         )

@@ -107,10 +125,10 @@ def configure_logging(self, log_filename: str):
         self.logging.configure_logging(self.experiment_dir, log_filename)

     def get_data_writer(self, data: "Data") -> MonthlyDataWriter:
-        n_months = months_for_timesteps(data.n_timesteps, data.timestep)
+        n_months = months_for_timesteps(data.n_timesteps, data.properties.timestep)
         coords = {
-            **data.loaders[0].dataset.horizontal_coordinates.coords,
-            **data.loaders[0].dataset.sigma_coordinates.coords,
+            **data.properties.horizontal_coordinates.coords,
+            **data.properties.vertical_coordinate.coords,
         }
         return MonthlyDataWriter(
             path=self.experiment_dir,
@@ -118,7 +136,7 @@
             save_names=None,  # save all data given
             n_samples=self.data_loader.batch_size * len(data.loaders),
             n_months=n_months,
-            metadata=data.loaders[0].dataset.metadata,
+            variable_metadata=data.properties.variable_metadata,
             coords=coords,
         )

@@ -126,8 +144,7 @@
 @dataclasses.dataclass
 class Data:
     loaders: List[torch.utils.data.DataLoader]
-    sigma_coordinates: SigmaCoordinates
-    timestep: datetime.timedelta
+    properties: DatasetProperties
     n_timesteps: int


@@ -135,12 +152,15 @@ def merge_loaders(loaders: List[torch.utils.data.DataLoader]):
     window_batch_data_list: List[BatchData]
     for window_batch_data_list in zip(*loaders):
         tensors = [item.data for item in window_batch_data_list]
-        times = [item.times for item in window_batch_data_list]
+        time = [item.time for item in window_batch_data_list]
         window_batch_data = {
             k: torch.concat([d[k] for d in tensors]) for k in tensors[0].keys()
         }
-        times = xr.concat(times, dim="sample")
-        yield BatchData(data=window_batch_data, times=times)
+        time = xr.concat(time, dim="sample")
+        yield BatchData(
+            data=window_batch_data,
+            time=time,
+        )


 def run(config: Config):
@@ -150,19 +170,25 @@
     data = config.get_data()
     writer = config.get_data_writer(data)

+    derive_func = AtmosphericDeriveFn(
+        vertical_coordinate=data.properties.vertical_coordinate,
+        timestep=data.properties.timestep,
+    )
+
     n_batches = len(data.loaders[0].dataset) // config.data_loader.batch_size
     for i, window_batch_data in enumerate(merge_loaders(data.loaders)):
         # no need to trim initial conditions because
         # we set n_timesteps to 1 in the DataRequirements
         assert list(window_batch_data.data.values())[0].shape[1] == 1
-        window_batch_data.data = compute_derived_quantities(
-            window_batch_data.data, data.sigma_coordinates, data.timestep
+        window_batch_data = window_batch_data.compute_derived_variables(
+            derive_func=derive_func,
+            forcing_data=window_batch_data,
         )
         writer.append_batch(
             data=window_batch_data.data,
             start_timestep=-1,  # ignored
-            batch_times=window_batch_data.times,
+            batch_time=window_batch_data.time,
         )
         if i % 10 == 0:
             logging.info(f"Writing batch {i+1} of {n_batches}.")
diff --git a/setup.cfg b/setup.cfg
deleted file mode 100644
index 4d358e7..0000000
--- a/setup.cfg
+++ /dev/null
@@ -1,4 +0,0 @@
-[flake8]
-exclude = docs
-ignore = E203,W293,W503,F541,E402
-max-line-length = 88
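For reference, the coarsening logic threaded through output_layer_indices in the ERA5 pipeline changes above can be illustrated on its own. This is a minimal sketch, not the pipeline code: the layer count, the synthetic ak/bk values, and the uniform thicknesses are assumptions, but the index selection and the thickness-weighted reduction mirror _get_vertical_coordinate and the per-output-layer loop in _process_native_data.

    import numpy as np
    import xarray as xr

    # hypothetical setup: 6 fine model layers collapsed into 2 coarse layers
    n_input_layers = 6
    output_layer_indices = [0, 3, 6]  # interfaces bounding the coarse layers

    # fine-grid hybrid coefficients; in the pipeline these come from the
    # GRIB_pv attribute, whose first n+1 values are ak and the rest bk
    ak = np.linspace(0.0, 100.0, n_input_layers + 1)
    bk = np.linspace(0.0, 1.0, n_input_layers + 1)
    ak_coarse = [ak[i] for i in output_layer_indices]
    bk_coarse = [bk[i] for i in output_layer_indices]

    # thickness-weighted reduction of a field over each coarse layer
    thicknesses = xr.DataArray(np.ones(n_input_layers), dims=["hybrid"])
    field = xr.DataArray(np.arange(n_input_layers, dtype=float), dims=["hybrid"])
    coarse_means = []
    for output_index in range(len(output_layer_indices) - 1):
        fine_levels = slice(
            output_layer_indices[output_index],
            output_layer_indices[output_index + 1],
        )
        layer_thickness = thicknesses.isel(hybrid=fine_levels)
        total_thickness = layer_thickness.sum("hybrid")
        weighted = (field.isel(hybrid=fine_levels) * layer_thickness).sum("hybrid")
        coarse_means.append(weighted / total_thickness)
    print([float(x) for x in coarse_means])  # [1.0, 4.0]

With uniform thicknesses the result is a plain mean over each group of fine levels; with pressure thicknesses it becomes the mass-weighted quantity the pipeline writes per output layer.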
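The new open_co2_dataset guards the requested time range against the bounds of the CO2 record before slicing and loading it. The sketch below shows the same pattern with an in-memory dataset standing in for the zarr store at URL_CO2; the helper name and the more descriptive assertion messages are illustrative additions, not part of the patch.

    import pandas as pd
    import xarray as xr


    def select_time_range(ds: xr.Dataset, start_time, end_time) -> xr.Dataset:
        ds_start = pd.Timestamp(ds.time.values[0])
        ds_stop = pd.Timestamp(ds.time.values[-1])
        # fail early, reporting the offending bounds
        assert start_time >= ds_start, f"start {start_time} precedes dataset start {ds_start}"
        assert end_time <= ds_stop, f"end {end_time} exceeds dataset end {ds_stop}"
        return ds.sel(time=slice(start_time, end_time)).load()


    co2 = xr.Dataset(
        {"co2": ("time", [400.0, 401.0, 402.0])},
        coords={"time": pd.date_range("2000-01-01", periods=3, freq="YS")},
    )
    subset = select_time_range(co2, pd.Timestamp("2000-01-01"), pd.Timestamp("2001-01-01"))
    print(subset.co2.values)  # [400. 401.]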
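The pipeline change forwards the new output_layer_indices argument through beam.MapTuple, which passes extra keyword arguments to the mapped function alongside each unpacked (key, value) element. A self-contained sketch of that mechanism with toy data (the process function and its defaults are assumptions; only the keyword-forwarding pattern matches the pipeline):

    import apache_beam as beam


    def process(key, value, output_grid="F90", output_layer_indices=(0, 2, 4)):
        n_layers = len(output_layer_indices) - 1
        return key, f"{value} -> {n_layers} layers on {output_grid}"


    with beam.Pipeline() as p:
        (
            p
            | beam.Create([("chunk-0", "native data")])
            | beam.MapTuple(
                process,
                output_grid="F180",
                output_layer_indices=[0, 10, 20, 30],
            )
            | beam.Map(print)  # ('chunk-0', 'native data -> 3 layers on F180')
        )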
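The CollateFn added to write_monthly_data.py replaces BatchData.from_sample_tuples as the DataLoader's collate_fn: it splits each sample into a tensor mapping and an xarray time coordinate, batches the former, and concatenates the latter along a "sample" dimension. The sketch below shows the same shape of operation without the fme internals; default_collate and BatchData are not reproduced here, and torch.stack plus a plain dict stand in for them.

    from typing import Dict, List, Sequence, Tuple

    import torch
    import xarray as xr


    def collate(
        samples: Sequence[Tuple[Dict[str, torch.Tensor], xr.DataArray]],
    ) -> Tuple[Dict[str, torch.Tensor], xr.DataArray]:
        sample_data, sample_time = zip(*samples)
        # stack each variable along a new leading batch dimension,
        # standing in for fme's default_collate
        batch_data = {
            name: torch.stack([d[name] for d in sample_data], dim=0)
            for name in sample_data[0]
        }
        batch_time = xr.concat(sample_time, dim="sample")
        return batch_data, batch_time


    samples = [
        (
            {"a": torch.zeros(2, 4, 8)},
            xr.DataArray(["2020-01-01", "2020-01-02"], dims=["time"]).expand_dims("sample"),
        )
        for _ in range(3)
    ]
    data, time = collate(samples)
    print(data["a"].shape, dict(time.sizes))  # torch.Size([3, 2, 4, 8]) {'sample': 3, 'time': 2}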
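test_write_monthly_data now takes a very_fast_only fixture and skips itself when only the quickest tests are requested. The conftest wiring is not part of this patch; one plausible implementation, with the option and fixture names assumed from the usage above, is:

    # conftest.py (hypothetical sketch; the real fixture lives in the repository's conftest)
    import pytest


    def pytest_addoption(parser):
        parser.addoption(
            "--very-fast",
            action="store_true",
            default=False,
            help="run only the quickest subset of tests",
        )


    @pytest.fixture
    def very_fast_only(request) -> bool:
        return request.config.getoption("--very-fast")

A test then opts out with "if very_fast_only: pytest.skip(...)", exactly as test_write_monthly_data does above.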