Commit

Merge branch 'develop' into feature/reduce-decoder-mem-usage

cathalobrien committed Nov 28, 2024
2 parents 7a86cf3 + fd2bcf1 commit 0fb033a
Showing 5 changed files with 102 additions and 1 deletion.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -22,6 +22,7 @@ Keep it human-readable, your future self will thank you!
- configurability of the dropout probability in the MultiHeadSelfAttention module
- Variable Bounding as configurable model layers [#13](https://github.com/ecmwf/anemoi-models/issues/13)
- GraphTransformerMapperBlock chunking to reduce memory usage during inference [#46](https://github.com/ecmwf/anemoi-models/pull/46)
- Mask NaN values in training loss function [#271](https://github.com/ecmwf-lab/aifs-mono/issues/271)
- New `NamedNodesAttributes` class to handle node attributes in a more flexible way [#64](https://github.com/ecmwf/anemoi-models/pull/64)
- Contributors file [#69](https://github.com/ecmwf/anemoi-models/pull/69)

13 changes: 12 additions & 1 deletion src/anemoi/models/preprocessing/imputer.py
@@ -43,6 +43,8 @@ def __init__(
        super().__init__(config, data_indices, statistics)

        self.nan_locations = None
        # weight imputed values with zero in the loss calculation
        self.loss_mask_training = None

    def _validate_indices(self):
        assert len(self.index_training_input) == len(self.index_inference_input) <= len(self.replacement), (
@@ -109,12 +111,21 @@ def transform(self, x: torch.Tensor, in_place: bool = True) -> torch.Tensor:
        if not in_place:
            x = x.clone()

-        # Initilialize mask once
+        # Initialize nan mask once
        if self.nan_locations is None:
            # The mask is only saved for the last two dimensions (grid, variable)
            idx = [slice(0, 1)] * (x.ndim - 2) + [slice(None), slice(None)]
            self.nan_locations = torch.isnan(x[idx].squeeze())

            # Initialize training loss mask to weigh imputed values with zeroes once
            self.loss_mask_training = torch.ones(
                (x.shape[-2], len(self.data_indices.model.output.name_to_index)), device=x.device
            )  # shape (grid, n_outputs)
            # for all variables that are imputed and part of the model output, set the loss weight to zero
            for idx_src, idx_dst in zip(self.index_training_input, self.index_inference_output):
                if idx_dst is not None:
                    self.loss_mask_training[:, idx_dst] = (~self.nan_locations[:, idx_src]).int()

        # Choose correct index based on number of variables
        if x.shape[-1] == self.num_training_input_vars:
            index = self.index_training_input
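As a reading of the change above: `loss_mask_training` has shape `(grid, n_outputs)` and holds zeros wherever a model-output variable was imputed, so a training loss can skip those points. A minimal sketch of how such a mask might be consumed downstream (the `masked_mse` helper is hypothetical, not part of this diff):

```python
import torch

def masked_mse(pred: torch.Tensor, target: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # mask: (grid, n_outputs); broadcasts over any leading batch/time dims.
    # Imputed points (mask == 0) contribute nothing to the loss.
    se = (pred - target) ** 2 * mask
    return se.sum() / mask.sum().clamp(min=1)

# e.g. loss = masked_mse(pred, target, imputer.loss_mask_training)
```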
29 changes: 29 additions & 0 deletions src/anemoi/models/preprocessing/remapper.py
@@ -224,6 +224,35 @@ def transform(self, x: torch.Tensor, in_place: bool = True) -> torch.Tensor:

        return x_remapped

    def transform_loss_mask(self, mask: torch.Tensor) -> torch.Tensor:
        """Remap the loss mask.

        Parameters
        ----------
        mask : torch.Tensor
            Loss mask
        """
        # use indices at model output level
        index = self.index_inference_backmapped_output
        indices_remapped = self.index_inference_output
        indices_keep = self.indices_keep_inference_output

        # create new loss mask with target number of columns
        mask_remapped = torch.zeros(
            mask.shape[:-1] + (mask.shape[-1] + len(indices_remapped),), dtype=mask.dtype, device=mask.device
        )

        # copy loss mask for variables that are not remapped
        mask_remapped[..., : len(indices_keep)] = mask[..., indices_keep]

        # remap loss mask for rest of variables
        for idx_src, idx_dst in zip(indices_remapped, index):
            if idx_dst is not None:
                for ii in idx_dst:
                    mask_remapped[..., ii] = mask[..., idx_src]

        return mask_remapped

    def inverse_transform(self, x: torch.Tensor, in_place: bool = True) -> torch.Tensor:
        """Convert and remap the output tensor.
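The method above copies a remapped variable's mask value into each of its target columns, with kept columns packed first. A standalone toy trace of that loop (the index lists here are made up for illustration):

```python
import torch

# Toy: model outputs [u, v, w], with w remapped to (w_a, w_b).
mask = torch.tensor([[1.0, 0.0, 1.0],
                     [1.0, 1.0, 0.0]])
indices_keep = [0, 1]    # u, v keep their columns
indices_remapped = [2]   # w is remapped
index = [[2, 3]]         # w -> columns of (w_a, w_b) in the new layout

mask_remapped = torch.zeros(mask.shape[:-1] + (mask.shape[-1] + len(indices_remapped),))
mask_remapped[..., : len(indices_keep)] = mask[..., indices_keep]
for idx_src, idx_dst in zip(indices_remapped, index):
    for ii in idx_dst:
        mask_remapped[..., ii] = mask[..., idx_src]

print(mask_remapped)  # tensor([[1., 0., 1., 1.], [1., 1., 0., 0.]])
```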
20 changes: 20 additions & 0 deletions tests/preprocessing/test_preprocessor_imputer.py
@@ -297,6 +297,26 @@ def test_mask_saving(imputer_fixture, data_fixture, request):
    assert torch.equal(imputer.nan_locations, expected_mask), "Mask not saved correctly after first run."


@pytest.mark.parametrize(
    ("imputer_fixture", "data_fixture"),
    [
        ("default_constant_imputer", "default_constant_data"),
        ("non_default_constant_imputer", "non_default_constant_data"),
        ("default_input_imputer", "default_input_data"),
        ("non_default_input_imputer", "non_default_input_data"),
    ],
)
def test_loss_nan_mask(imputer_fixture, data_fixture, request):
    """Check that the imputer computes the NaN loss mask correctly."""
    x, _ = request.getfixturevalue(data_fixture)
    expected = torch.tensor([[1.0, 1.0, 1.0], [1.0, 0.0, 1.0]])  # only prognostic and diagnostic variables
    imputer = request.getfixturevalue(imputer_fixture)
    imputer.transform(x)
    assert torch.allclose(
        imputer.loss_mask_training, expected
    ), "Transform does not calculate NaN-mask for loss function scaling correctly."


@pytest.mark.parametrize(
    ("imputer_fixture", "data_fixture"),
    [
40 changes: 40 additions & 0 deletions tests/preprocessing/test_preprocessor_remapper.py
@@ -8,11 +8,13 @@
# nor does it submit to any jurisdiction.


import numpy as np
import pytest
import torch
from omegaconf import DictConfig

from anemoi.models.data_indices.collection import IndexCollection
from anemoi.models.preprocessing.imputer import InputImputer
from anemoi.models.preprocessing.remapper import Remapper


@@ -41,6 +43,34 @@ def input_remapper():
    return Remapper(config=config.data.remapper, data_indices=data_indices, statistics=statistics)


@pytest.fixture()
def input_imputer():
    config = DictConfig(
        {
            "diagnostics": {"log": {"code": {"level": "DEBUG"}}},
            "data": {
                "remapper": {
                    "cos_sin": {
                        "d": ["cos_d", "sin_d"],
                    }
                },
                "imputer": {"default": "none", "mean": ["y", "d"]},
                "forcing": ["z", "q"],
                "diagnostic": ["other"],
                "remapped": {
                    "d": ["cos_d", "sin_d"],
                },
            },
        },
    )
    statistics = {
        "mean": np.array([1.0, 2.0, 3.0, 4.5, 3.0, 1.0]),
    }
    name_to_index = {"x": 0, "y": 1, "z": 2, "q": 3, "d": 4, "other": 5}
    data_indices = IndexCollection(config=config, name_to_index=name_to_index)
    return InputImputer(config=config.data.imputer, data_indices=data_indices, statistics=statistics)


def test_remap_not_inplace(input_remapper) -> None:
    x = torch.Tensor([[1.0, 2.0, 3.0, 4.0, 150.0, 5.0], [6.0, 7.0, 8.0, 9.0, 201.0, 10.0]])
    input_remapper(x, in_place=False)
@@ -66,3 +96,13 @@ def test_remap_inverse_transform(input_remapper) -> None:
    assert torch.allclose(
        input_remapper.inverse_transform(input_remapper.transform(x, in_place=False), in_place=False), x
    )


def test_transform_loss_mask(input_imputer, input_remapper) -> None:
    x = torch.Tensor([[1.0, np.nan, 3.0, 4.0, 150.0, 5.0], [6.0, 7.0, 8.0, 9.0, np.nan, 10.0]])
    expected_output = torch.Tensor([[1.0, 0.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 0.0, 0.0]])
    input_imputer.transform(x)
    input_remapper.transform(x)
    loss_mask_training = input_imputer.loss_mask_training
    loss_mask_training = input_remapper.transform_loss_mask(loss_mask_training)
    assert torch.allclose(loss_mask_training, expected_output)
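Why `expected_output` has five columns (my reading of the fixtures above, not stated explicitly in the diff): the forcing variables `z` and `q` are input-only and drop out of the model output, and `d` is remapped to `cos_d`/`sin_d`:

```python
# Remapped loss-mask column layout: [x, y, other, cos_d, sin_d]
# Row 0: NaN at y  -> y column is 0          -> [1, 0, 1, 1, 1]
# Row 1: NaN at d  -> cos_d and sin_d are 0  -> [1, 1, 1, 0, 0]
```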
