-
Notifications
You must be signed in to change notification settings - Fork 13
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Feature/hierarchical graphs #37
Open
icedoom888
wants to merge
24
commits into
ecmwf:develop
Choose a base branch
from
MeteoSwiss:feature/hierarchical_graphs
base: develop
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from 20 commits
Commits
Show all changes
24 commits
Select commit
Hold shift + click to select a range
5869f18
Hard coded implementation of the hierarchical graph model
icedoom888 022948b
Added implementation of Hierarchical Graph networks
icedoom888 d74bc57
Added instantiate model in interface init
icedoom888 1d5e44a
if-else branching instead of hydra:instantiate, have to fix this in t…
icedoom888 b2e1f2b
Added changes before migration
icedoom888 a888f9b
WORKING implementation of Hierarchical graph network
icedoom888 dc7553c
Refactor and cleaning
icedoom888 7ba3519
Added example config
icedoom888 212d0a1
Merge branch 'ecmwf:develop' into feature/hierarchical_graphs
icedoom888 1b3376f
Minor refactor
icedoom888 e7deb30
Merge branch 'feature/hierarchical_graphs' of github.com:MeteoSwiss-A…
icedoom888 8533ac3
Refactor
icedoom888 32a49cd
Merge branch 'develop' into feature/hierarchical_graphs
icedoom888 3eb882b
Refactor and rebase
icedoom888 4aebd8f
Refactor and small changes for merge
icedoom888 53a23d4
Re-added asserts in mapper
icedoom888 368418a
Added entry in changelog
icedoom888 80d2835
Refactor pre-merge
icedoom888 689e3fe
Refactored the hierarchical model
icedoom888 32a0ac8
Test dimentions completed.
icedoom888 e58ef71
Fixed dynamo issue
icedoom888 95f59a1
Rebased on develop
icedoom888 8e6c00a
Refactored using NamedNodesAttributes
icedoom888 a257e37
Fixed with git pre-commit
icedoom888 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
from .encoder_processor_decoder import AnemoiModelEncProcDec | ||
from .hierarchical import AnemoiModelEncProcDecHierarchical | ||
|
||
__all__ = ["AnemoiModelEncProcDec", "AnemoiModelEncProcDecHierarchical"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,347 @@ | ||
# (C) Copyright 2024 ECMWF. | ||
# | ||
# This software is licensed under the terms of the Apache Licence Version 2.0 | ||
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. | ||
# In applying this licence, ECMWF does not waive the privileges and immunities | ||
# granted to it by virtue of its status as an intergovernmental organisation | ||
# nor does it submit to any jurisdiction. | ||
# | ||
|
||
import logging | ||
from typing import Optional | ||
|
||
import einops | ||
import torch | ||
from anemoi.utils.config import DotDict | ||
from hydra.utils import instantiate | ||
from torch import Tensor | ||
from torch import nn | ||
from torch.distributed.distributed_c10d import ProcessGroup | ||
from torch_geometric.data import HeteroData | ||
|
||
from anemoi.models.distributed.shapes import get_shape_shards | ||
from anemoi.models.layers.graph import TrainableTensor | ||
from anemoi.models.models import AnemoiModelEncProcDec | ||
|
||
LOGGER = logging.getLogger(__name__) | ||
|
||
|
||
class AnemoiModelEncProcDecHierarchical(AnemoiModelEncProcDec): | ||
"""Message passing hierarchical graph neural network.""" | ||
|
||
def __init__( | ||
self, | ||
*, | ||
model_config: DotDict, | ||
data_indices: dict, | ||
graph_data: HeteroData, | ||
) -> None: | ||
"""Initializes the graph neural network. | ||
|
||
Parameters | ||
---------- | ||
config : DotDict | ||
Job configuration | ||
data_indices : dict | ||
Data indices | ||
graph_data : HeteroData | ||
Graph definition | ||
""" | ||
nn.Module.__init__(self) | ||
|
||
JPXKQX marked this conversation as resolved.
Show resolved
Hide resolved
|
||
self._graph_data = graph_data | ||
|
||
## Get hidden layers names | ||
self._graph_name_data = model_config.graph.data | ||
self._graph_hidden_names = model_config.graph.hidden | ||
self.num_hidden = len(self._graph_hidden_names) | ||
|
||
# Unpack config for hierarchical graph | ||
self.level_process = model_config.model.enable_hierarchical_level_processing | ||
|
||
# hidden_dims is the dimentionality of features at each depth | ||
self.hidden_dims = { | ||
hidden: model_config.model.num_channels * (2**i) for i, hidden in enumerate(self._graph_hidden_names) | ||
} | ||
|
||
self._calculate_shapes_and_indices(data_indices) | ||
self._assert_matching_indices(data_indices) | ||
self.multi_step = model_config.training.multistep_input | ||
self._define_tensor_sizes(model_config) | ||
|
||
# Create trainable tensors | ||
self._create_trainable_attributes() | ||
|
||
# Register lat/lon for all nodes | ||
self._register_latlon("data", "data") | ||
for hidden in self._graph_hidden_names: | ||
self._register_latlon(hidden, hidden) | ||
|
||
input_dim = self.multi_step * self.num_input_channels + self.latlons_data.shape[1] + self.trainable_data_size | ||
|
||
print( | ||
input_dim, | ||
getattr(self, f"latlons_{self._graph_hidden_names[0]}").shape[1], | ||
self.hidden_dims[self._graph_hidden_names[0]], | ||
) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is this a leftover debug? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Removed ✔️ |
||
# Encoder data -> hidden | ||
self.encoder = instantiate( | ||
model_config.model.encoder, | ||
in_channels_src=input_dim, | ||
in_channels_dst=getattr(self, f"latlons_{self._graph_hidden_names[0]}").shape[1] | ||
+ self.trainable_hidden_size, | ||
hidden_dim=self.hidden_dims[self._graph_hidden_names[0]], | ||
sub_graph=self._graph_data[(self._graph_name_data, "to", self._graph_hidden_names[0])], | ||
src_grid_size=self._data_grid_size, | ||
dst_grid_size=self._hidden_grid_sizes[self._graph_hidden_names[0]], | ||
) | ||
|
||
# Level processors | ||
if self.level_process: | ||
self.down_level_processor = nn.ModuleDict() | ||
self.up_level_processor = nn.ModuleDict() | ||
|
||
for i in range(0, self.num_hidden): | ||
nodes_names = self._graph_hidden_names[i] | ||
|
||
self.down_level_processor[nodes_names] = instantiate( | ||
model_config.model.processor, | ||
num_channels=self.hidden_dims[nodes_names], | ||
sub_graph=self._graph_data[(nodes_names, "to", nodes_names)], | ||
src_grid_size=self._hidden_grid_sizes[nodes_names], | ||
dst_grid_size=self._hidden_grid_sizes[nodes_names], | ||
num_layers=model_config.model.level_process_num_layers, | ||
) | ||
|
||
self.up_level_processor[nodes_names] = instantiate( | ||
model_config.model.processor, | ||
num_channels=self.hidden_dims[nodes_names], | ||
sub_graph=self._graph_data[(nodes_names, "to", nodes_names)], | ||
src_grid_size=self._hidden_grid_sizes[nodes_names], | ||
dst_grid_size=self._hidden_grid_sizes[nodes_names], | ||
num_layers=model_config.model.level_process_num_layers, | ||
) | ||
|
||
# delete final upscale (does not exist): |->|->|<-|<-| | ||
del self.up_level_processor[nodes_names] | ||
|
||
# Downscale | ||
self.downscale = nn.ModuleDict() | ||
|
||
for i in range(0, self.num_hidden - 1): | ||
src_nodes_name = self._graph_hidden_names[i] | ||
dst_nodes_name = self._graph_hidden_names[i + 1] | ||
|
||
self.downscale[src_nodes_name] = instantiate( | ||
model_config.model.encoder, | ||
in_channels_src=self.hidden_dims[src_nodes_name], | ||
in_channels_dst=getattr(self, f"latlons_{dst_nodes_name}").shape[1] + self.trainable_hidden_size, | ||
hidden_dim=self.hidden_dims[dst_nodes_name], | ||
sub_graph=self._graph_data[(src_nodes_name, "to", dst_nodes_name)], | ||
src_grid_size=self._hidden_grid_sizes[src_nodes_name], | ||
dst_grid_size=self._hidden_grid_sizes[dst_nodes_name], | ||
) | ||
|
||
# Upscale | ||
self.upscale = nn.ModuleDict() | ||
|
||
for i in range(1, self.num_hidden): | ||
src_nodes_name = self._graph_hidden_names[i] | ||
dst_nodes_name = self._graph_hidden_names[i - 1] | ||
|
||
self.upscale[src_nodes_name] = instantiate( | ||
model_config.model.decoder, | ||
in_channels_src=self.hidden_dims[src_nodes_name], | ||
in_channels_dst=self.hidden_dims[dst_nodes_name], | ||
hidden_dim=self.hidden_dims[src_nodes_name], | ||
out_channels_dst=self.hidden_dims[dst_nodes_name], | ||
sub_graph=self._graph_data[(src_nodes_name, "to", dst_nodes_name)], | ||
src_grid_size=self._hidden_grid_sizes[src_nodes_name], | ||
dst_grid_size=self._hidden_grid_sizes[dst_nodes_name], | ||
) | ||
|
||
# Decoder hidden -> data | ||
self.decoder = instantiate( | ||
model_config.model.decoder, | ||
in_channels_src=self.hidden_dims[self._graph_hidden_names[0]], | ||
in_channels_dst=input_dim, | ||
hidden_dim=self.hidden_dims[self._graph_hidden_names[0]], | ||
out_channels_dst=self.num_output_channels, | ||
sub_graph=self._graph_data[(self._graph_hidden_names[0], "to", self._graph_name_data)], | ||
src_grid_size=self._hidden_grid_sizes[self._graph_hidden_names[0]], | ||
dst_grid_size=self._data_grid_size, | ||
) | ||
|
||
# Instantiation of model output bounding functions (e.g., to ensure outputs like TP are positive definite) | ||
self.boundings = nn.ModuleList( | ||
[ | ||
instantiate(cfg, name_to_index=self.data_indices.internal_model.output.name_to_index) | ||
for cfg in getattr(model_config.model, "bounding", []) | ||
] | ||
) | ||
|
||
def _define_tensor_sizes(self, config: DotDict) -> None: | ||
|
||
# Grid sizes | ||
self._data_grid_size = self._graph_data[self._graph_name_data].num_nodes | ||
self._hidden_grid_sizes = {} | ||
for hidden in self._graph_hidden_names: | ||
self._hidden_grid_sizes[hidden] = self._graph_data[hidden].num_nodes | ||
|
||
# trainable sizes | ||
self.trainable_data_size = config.model.trainable_parameters.data | ||
self.trainable_hidden_size = config.model.trainable_parameters.hidden | ||
|
||
def _create_trainable_attributes(self) -> None: | ||
"""Create all trainable attributes.""" | ||
self.trainable_data = TrainableTensor(trainable_size=self.trainable_data_size, tensor_size=self._data_grid_size) | ||
self.trainable_hidden = nn.ModuleDict() | ||
|
||
for hidden in self._graph_hidden_names: | ||
self.trainable_hidden[hidden] = TrainableTensor( | ||
trainable_size=self.trainable_hidden_size, tensor_size=self._hidden_grid_sizes[hidden] | ||
) | ||
|
||
def forward(self, x: Tensor, model_comm_group: Optional[ProcessGroup] = None) -> Tensor: | ||
batch_size = x.shape[0] | ||
ensemble_size = x.shape[2] | ||
|
||
# add data positional info (lat/lon) | ||
x_trainable_data = torch.cat( | ||
( | ||
einops.rearrange(x, "batch time ensemble grid vars -> (batch ensemble grid) (time vars)"), | ||
self.trainable_data(self.latlons_data, batch_size=batch_size), | ||
), | ||
dim=-1, # feature dimension | ||
) | ||
|
||
# Get all trainable parameters for the hidden layers -> initialisation of each hidden, which becomes trainable bias | ||
x_trainable_hiddens = {} | ||
for hidden in self._graph_hidden_names: | ||
x_trainable_hiddens[hidden] = self.trainable_hidden[hidden]( | ||
getattr(self, f"latlons_{hidden}"), batch_size=batch_size | ||
) | ||
|
||
# Get data and hidden shapes for sharding | ||
shard_shapes_data = get_shape_shards(x_trainable_data, 0, model_comm_group) | ||
shard_shapes_hiddens = {} | ||
for hidden, x_latent in x_trainable_hiddens.items(): | ||
shard_shapes_hiddens[hidden] = get_shape_shards(x_latent, 0, model_comm_group) | ||
|
||
# print('Input: ', x_trainable_data.shape, x_trainable_hiddens[self._graph_hidden_names[0]].shape, shard_shapes_data, shard_shapes_hiddens[self._graph_hidden_names[0]]) | ||
|
||
# Run encoder | ||
x_data_latent, curr_latent = self._run_mapper( | ||
self.encoder, | ||
(x_trainable_data, x_trainable_hiddens[self._graph_hidden_names[0]]), | ||
batch_size=batch_size, | ||
shard_shapes=(shard_shapes_data, shard_shapes_hiddens[self._graph_hidden_names[0]]), | ||
model_comm_group=model_comm_group, | ||
) | ||
|
||
# print('After encoding: ', x_data_latent.shape, curr_latent.shape) | ||
|
||
# Run processor | ||
x_encoded_latents = {} | ||
x_skip = {} | ||
|
||
## Downscale | ||
for i in range(0, self.num_hidden - 1): | ||
src_hidden_name = self._graph_hidden_names[i] | ||
dst_hidden_name = self._graph_hidden_names[i + 1] | ||
|
||
# Processing at same level | ||
if self.level_process: | ||
curr_latent = self.down_level_processor[src_hidden_name]( | ||
curr_latent, | ||
batch_size=batch_size, | ||
shard_shapes=shard_shapes_hiddens[src_hidden_name], | ||
model_comm_group=model_comm_group, | ||
) | ||
# print(f'After level of {src_hidden_name}: ', curr_latent.shape) | ||
|
||
# store latents for skip connections | ||
x_skip[src_hidden_name] = curr_latent | ||
|
||
# Encode to next hidden level | ||
x_encoded_latents[src_hidden_name], curr_latent = self._run_mapper( | ||
self.downscale[src_hidden_name], | ||
(curr_latent, x_trainable_hiddens[dst_hidden_name]), | ||
batch_size=batch_size, | ||
shard_shapes=(shard_shapes_hiddens[src_hidden_name], shard_shapes_hiddens[dst_hidden_name]), | ||
model_comm_group=model_comm_group, | ||
) | ||
|
||
# print(f'After downscaling of {src_hidden_name}: ', curr_latent.shape) | ||
|
||
# Processing hidden-most level | ||
if self.level_process: | ||
curr_latent = self.down_level_processor[dst_hidden_name]( | ||
curr_latent, | ||
batch_size=batch_size, | ||
shard_shapes=shard_shapes_hiddens[dst_hidden_name], | ||
model_comm_group=model_comm_group, | ||
) | ||
|
||
# print(f'After level of {dst_hidden_name}: ', curr_latent.shape) | ||
|
||
## Upscale | ||
for i in range(self.num_hidden - 1, 0, -1): | ||
src_hidden_name = self._graph_hidden_names[i] | ||
dst_hidden_name = self._graph_hidden_names[i - 1] | ||
|
||
# Process to next level | ||
curr_latent = self._run_mapper( | ||
self.upscale[src_hidden_name], | ||
(curr_latent, x_encoded_latents[dst_hidden_name]), | ||
batch_size=batch_size, | ||
shard_shapes=(shard_shapes_hiddens[src_hidden_name], shard_shapes_hiddens[dst_hidden_name]), | ||
model_comm_group=model_comm_group, | ||
) | ||
|
||
# print(f'After upscaling of {src_hidden_name}: ', curr_latent.shape) | ||
|
||
# Add skip connections | ||
curr_latent = curr_latent + x_skip[dst_hidden_name] | ||
|
||
# Processing at same level | ||
if self.level_process: | ||
curr_latent = self.up_level_processor[dst_hidden_name]( | ||
curr_latent, | ||
batch_size=batch_size, | ||
shard_shapes=shard_shapes_hiddens[dst_hidden_name], | ||
model_comm_group=model_comm_group, | ||
) | ||
|
||
# print(f'After level of {dst_hidden_name}: ', curr_latent.shape) | ||
|
||
# Run decoder | ||
x_out = self._run_mapper( | ||
self.decoder, | ||
(curr_latent, x_data_latent), | ||
batch_size=batch_size, | ||
shard_shapes=(shard_shapes_hiddens[self._graph_hidden_names[0]], shard_shapes_data), | ||
model_comm_group=model_comm_group, | ||
) | ||
|
||
# print('After decoding: ', x_out.shape) | ||
|
||
x_out = ( | ||
einops.rearrange( | ||
x_out, | ||
"(batch ensemble grid) vars -> batch ensemble grid vars", | ||
batch=batch_size, | ||
ensemble=ensemble_size, | ||
) | ||
.to(dtype=x.dtype) | ||
.clone() | ||
) | ||
|
||
# residual connection (just for the prognostic variables) | ||
x_out[..., self._internal_output_idx] += x[:, -1, :, :, self._internal_input_idx] | ||
|
||
for bounding in self.boundings: | ||
# bounding performed in the order specified in the config file | ||
x_out = bounding(x_out) | ||
|
||
return x_out |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can the copyright statement be updated in accordance with #67?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done ✔️