
Feature/hierarchical graphs #37

Open
wants to merge 24 commits into base: develop
Changes from 20 commits

Commits (24)
5869f18
Hard coded implementation of the hierarchical graph model
icedoom888 Aug 30, 2024
022948b
Added implementation of Hierarchical Graph networks
icedoom888 Sep 3, 2024
d74bc57
Added instantiate model in interface init
icedoom888 Sep 3, 2024
1d5e44a
if-else branching instead of hydra:instantiate, have to fix this in t…
icedoom888 Sep 4, 2024
b2e1f2b
Added changes before migration
icedoom888 Sep 5, 2024
a888f9b
WORKING implementation of Hierarchical graph network
icedoom888 Sep 11, 2024
dc7553c
Refactor and cleaning
icedoom888 Sep 11, 2024
7ba3519
Added example config
icedoom888 Sep 11, 2024
212d0a1
Merge branch 'ecmwf:develop' into feature/hierarchical_graphs
icedoom888 Sep 11, 2024
1b3376f
Minor refactor
icedoom888 Sep 11, 2024
e7deb30
Merge branch 'feature/hierarchical_graphs' of github.com:MeteoSwiss-A…
icedoom888 Sep 11, 2024
8533ac3
Refactor
icedoom888 Oct 9, 2024
32a49cd
Merge branch 'develop' into feature/hierarchical_graphs
icedoom888 Oct 17, 2024
3eb882b
Refactor and rebase
icedoom888 Oct 17, 2024
4aebd8f
Refactor and small changes for merge
icedoom888 Oct 17, 2024
53a23d4
Re-added asserts in mapper
icedoom888 Oct 17, 2024
368418a
Added entry in changelog
icedoom888 Oct 17, 2024
80d2835
Refactor pre-merge
icedoom888 Oct 17, 2024
689e3fe
Refactored the hierarchical model
icedoom888 Oct 23, 2024
32a0ac8
Test dimentions completed.
icedoom888 Oct 28, 2024
e58ef71
Fixed dynamo issue
icedoom888 Oct 30, 2024
95f59a1
Rebased on develop
icedoom888 Nov 25, 2024
8e6c00a
Refactored using NamedNodesAttributes
icedoom888 Nov 25, 2024
a257e37
Fixed with git pre-commit
icedoom888 Nov 25, 2024
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -12,6 +12,7 @@ Keep it human-readable, your future self will thank you!

### Added

- New AnemoiModelEncProcDecHierarchical class available in models [#37](https://github.com/ecmwf/anemoi-models/pull/37)
- Add anemoi-transform link to documentation
- Codeowners file
- Pygrep precommit hooks
1 change: 1 addition & 0 deletions src/anemoi/models/layers/processor.py
@@ -322,6 +322,7 @@ def forward(
*args,
**kwargs,
) -> Tensor:

shape_nodes = change_channels_in_shape(shard_shapes, self.num_channels)
edge_attr = self.trainable(self.edge_attr, batch_size)

4 changes: 4 additions & 0 deletions src/anemoi/models/models/__init__.py
@@ -0,0 +1,4 @@
from .encoder_processor_decoder import AnemoiModelEncProcDec
from .hierarchical import AnemoiModelEncProcDecHierarchical

__all__ = ["AnemoiModelEncProcDec", "AnemoiModelEncProcDecHierarchical"]
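A minimal usage sketch (assuming this PR's branch of anemoi-models is installed): both model classes are importable from the new anemoi.models.models package, and the hierarchical model subclasses the existing flat encoder-processor-decoder.

from anemoi.models.models import AnemoiModelEncProcDec
from anemoi.models.models import AnemoiModelEncProcDecHierarchical

# The hierarchical model reuses the flat model's helpers (mappers, lat/lon
# registration, index bookkeeping) via inheritance.
assert issubclass(AnemoiModelEncProcDecHierarchical, AnemoiModelEncProcDec)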
347 changes: 347 additions & 0 deletions src/anemoi/models/models/hierarchical.py
@@ -0,0 +1,347 @@
# (C) Copyright 2024 ECMWF.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.
#

Comment on lines +1 to +9
Member: Can the copyright statement be updated in accordance with #67?
Author: Done ✔️

import logging
from typing import Optional

import einops
import torch
from anemoi.utils.config import DotDict
from hydra.utils import instantiate
from torch import Tensor
from torch import nn
from torch.distributed.distributed_c10d import ProcessGroup
from torch_geometric.data import HeteroData

from anemoi.models.distributed.shapes import get_shape_shards
from anemoi.models.layers.graph import TrainableTensor
from anemoi.models.models import AnemoiModelEncProcDec

LOGGER = logging.getLogger(__name__)


class AnemoiModelEncProcDecHierarchical(AnemoiModelEncProcDec):
"""Message passing hierarchical graph neural network."""

def __init__(
self,
*,
model_config: DotDict,
data_indices: dict,
graph_data: HeteroData,
) -> None:
"""Initializes the graph neural network.

Parameters
----------
model_config : DotDict
Job configuration
data_indices : dict
Data indices
graph_data : HeteroData
Graph definition
"""
nn.Module.__init__(self)

JPXKQX marked this conversation as resolved.
self._graph_data = graph_data

## Get hidden layer names
self._graph_name_data = model_config.graph.data
self._graph_hidden_names = model_config.graph.hidden
self.num_hidden = len(self._graph_hidden_names)

# Unpack config for hierarchical graph
self.level_process = model_config.model.enable_hierarchical_level_processing
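# When enabled, a same-level processor is applied at every hidden mesh on both
# the downward and upward passes (see down_level_processor / up_level_processor below).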

# hidden_dims is the dimensionality of features at each depth
self.hidden_dims = {
hidden: model_config.model.num_channels * (2**i) for i, hidden in enumerate(self._graph_hidden_names)
}
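# Worked example (illustrative values): with num_channels = 128 and hidden meshes
# ["hidden_1", "hidden_2", "hidden_3"], hidden_dims = {"hidden_1": 128, "hidden_2": 256, "hidden_3": 512}.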

self._calculate_shapes_and_indices(data_indices)
self._assert_matching_indices(data_indices)
self.multi_step = model_config.training.multistep_input
self._define_tensor_sizes(model_config)

# Create trainable tensors
self._create_trainable_attributes()

# Register lat/lon for all nodes
self._register_latlon("data", "data")
for hidden in self._graph_hidden_names:
self._register_latlon(hidden, hidden)

input_dim = self.multi_step * self.num_input_channels + self.latlons_data.shape[1] + self.trainable_data_size
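# Illustrative example: with multi_step = 2, 90 input channels, 4 lat/lon features
# and trainable_data_size = 8, input_dim = 2 * 90 + 4 + 8 = 192 (all values hypothetical).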

print(
input_dim,
getattr(self, f"latlons_{self._graph_hidden_names[0]}").shape[1],
self.hidden_dims[self._graph_hidden_names[0]],
)
Member: Is this a leftover debug?
Author: Removed ✔️

# Encoder data -> hidden
self.encoder = instantiate(
model_config.model.encoder,
in_channels_src=input_dim,
in_channels_dst=getattr(self, f"latlons_{self._graph_hidden_names[0]}").shape[1]
+ self.trainable_hidden_size,
hidden_dim=self.hidden_dims[self._graph_hidden_names[0]],
sub_graph=self._graph_data[(self._graph_name_data, "to", self._graph_hidden_names[0])],
src_grid_size=self._data_grid_size,
dst_grid_size=self._hidden_grid_sizes[self._graph_hidden_names[0]],
)

# Level processors
if self.level_process:
self.down_level_processor = nn.ModuleDict()
self.up_level_processor = nn.ModuleDict()

for i in range(0, self.num_hidden):
nodes_names = self._graph_hidden_names[i]

self.down_level_processor[nodes_names] = instantiate(
model_config.model.processor,
num_channels=self.hidden_dims[nodes_names],
sub_graph=self._graph_data[(nodes_names, "to", nodes_names)],
src_grid_size=self._hidden_grid_sizes[nodes_names],
dst_grid_size=self._hidden_grid_sizes[nodes_names],
num_layers=model_config.model.level_process_num_layers,
)

self.up_level_processor[nodes_names] = instantiate(
model_config.model.processor,
num_channels=self.hidden_dims[nodes_names],
sub_graph=self._graph_data[(nodes_names, "to", nodes_names)],
src_grid_size=self._hidden_grid_sizes[nodes_names],
dst_grid_size=self._hidden_grid_sizes[nodes_names],
num_layers=model_config.model.level_process_num_layers,
)

# remove the up-level processor at the deepest hidden level, where no up-processing occurs: |->|->|<-|<-|
del self.up_level_processor[nodes_names]

# Downscale
self.downscale = nn.ModuleDict()

for i in range(0, self.num_hidden - 1):
src_nodes_name = self._graph_hidden_names[i]
dst_nodes_name = self._graph_hidden_names[i + 1]

self.downscale[src_nodes_name] = instantiate(
model_config.model.encoder,
in_channels_src=self.hidden_dims[src_nodes_name],
in_channels_dst=getattr(self, f"latlons_{dst_nodes_name}").shape[1] + self.trainable_hidden_size,
hidden_dim=self.hidden_dims[dst_nodes_name],
sub_graph=self._graph_data[(src_nodes_name, "to", dst_nodes_name)],
src_grid_size=self._hidden_grid_sizes[src_nodes_name],
dst_grid_size=self._hidden_grid_sizes[dst_nodes_name],
)

# Upscale
self.upscale = nn.ModuleDict()

for i in range(1, self.num_hidden):
src_nodes_name = self._graph_hidden_names[i]
dst_nodes_name = self._graph_hidden_names[i - 1]

self.upscale[src_nodes_name] = instantiate(
model_config.model.decoder,
in_channels_src=self.hidden_dims[src_nodes_name],
in_channels_dst=self.hidden_dims[dst_nodes_name],
hidden_dim=self.hidden_dims[src_nodes_name],
out_channels_dst=self.hidden_dims[dst_nodes_name],
sub_graph=self._graph_data[(src_nodes_name, "to", dst_nodes_name)],
src_grid_size=self._hidden_grid_sizes[src_nodes_name],
dst_grid_size=self._hidden_grid_sizes[dst_nodes_name],
)

# Decoder hidden -> data
self.decoder = instantiate(
model_config.model.decoder,
in_channels_src=self.hidden_dims[self._graph_hidden_names[0]],
in_channels_dst=input_dim,
hidden_dim=self.hidden_dims[self._graph_hidden_names[0]],
out_channels_dst=self.num_output_channels,
sub_graph=self._graph_data[(self._graph_hidden_names[0], "to", self._graph_name_data)],
src_grid_size=self._hidden_grid_sizes[self._graph_hidden_names[0]],
dst_grid_size=self._data_grid_size,
)

# Instantiation of model output bounding functions (e.g., to ensure outputs like TP are positive definite)
self.boundings = nn.ModuleList(
[
instantiate(cfg, name_to_index=self.data_indices.internal_model.output.name_to_index)
for cfg in getattr(model_config.model, "bounding", [])
]
)

def _define_tensor_sizes(self, config: DotDict) -> None:
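"""Store the data/hidden grid sizes and the trainable tensor sizes from the config."""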

# Grid sizes
self._data_grid_size = self._graph_data[self._graph_name_data].num_nodes
self._hidden_grid_sizes = {}
for hidden in self._graph_hidden_names:
self._hidden_grid_sizes[hidden] = self._graph_data[hidden].num_nodes

# trainable sizes
self.trainable_data_size = config.model.trainable_parameters.data
self.trainable_hidden_size = config.model.trainable_parameters.hidden

def _create_trainable_attributes(self) -> None:
"""Create all trainable attributes."""
self.trainable_data = TrainableTensor(trainable_size=self.trainable_data_size, tensor_size=self._data_grid_size)
self.trainable_hidden = nn.ModuleDict()

for hidden in self._graph_hidden_names:
self.trainable_hidden[hidden] = TrainableTensor(
trainable_size=self.trainable_hidden_size, tensor_size=self._hidden_grid_sizes[hidden]
)

def forward(self, x: Tensor, model_comm_group: Optional[ProcessGroup] = None) -> Tensor:
batch_size = x.shape[0]
ensemble_size = x.shape[2]

# add data positional info (lat/lon)
x_trainable_data = torch.cat(
(
einops.rearrange(x, "batch time ensemble grid vars -> (batch ensemble grid) (time vars)"),
self.trainable_data(self.latlons_data, batch_size=batch_size),
),
dim=-1, # feature dimension
)

# Get the trainable parameters for each hidden mesh; these initialise the hidden states and act as a trainable bias
x_trainable_hiddens = {}
for hidden in self._graph_hidden_names:
x_trainable_hiddens[hidden] = self.trainable_hidden[hidden](
getattr(self, f"latlons_{hidden}"), batch_size=batch_size
)

# Get data and hidden shapes for sharding
shard_shapes_data = get_shape_shards(x_trainable_data, 0, model_comm_group)
shard_shapes_hiddens = {}
for hidden, x_latent in x_trainable_hiddens.items():
shard_shapes_hiddens[hidden] = get_shape_shards(x_latent, 0, model_comm_group)

# print('Input: ', x_trainable_data.shape, x_trainable_hiddens[self._graph_hidden_names[0]].shape, shard_shapes_data, shard_shapes_hiddens[self._graph_hidden_names[0]])

# Run encoder
x_data_latent, curr_latent = self._run_mapper(
self.encoder,
(x_trainable_data, x_trainable_hiddens[self._graph_hidden_names[0]]),
batch_size=batch_size,
shard_shapes=(shard_shapes_data, shard_shapes_hiddens[self._graph_hidden_names[0]]),
model_comm_group=model_comm_group,
)

# print('After encoding: ', x_data_latent.shape, curr_latent.shape)

# Run processor
x_encoded_latents = {}
x_skip = {}

## Downscale
for i in range(0, self.num_hidden - 1):
src_hidden_name = self._graph_hidden_names[i]
dst_hidden_name = self._graph_hidden_names[i + 1]

# Processing at same level
if self.level_process:
curr_latent = self.down_level_processor[src_hidden_name](
curr_latent,
batch_size=batch_size,
shard_shapes=shard_shapes_hiddens[src_hidden_name],
model_comm_group=model_comm_group,
)
# print(f'After level of {src_hidden_name}: ', curr_latent.shape)

# store latents for skip connections
x_skip[src_hidden_name] = curr_latent

# Encode to next hidden level
x_encoded_latents[src_hidden_name], curr_latent = self._run_mapper(
self.downscale[src_hidden_name],
(curr_latent, x_trainable_hiddens[dst_hidden_name]),
batch_size=batch_size,
shard_shapes=(shard_shapes_hiddens[src_hidden_name], shard_shapes_hiddens[dst_hidden_name]),
model_comm_group=model_comm_group,
)

# print(f'After downscaling of {src_hidden_name}: ', curr_latent.shape)

# Processing hidden-most level
if self.level_process:
curr_latent = self.down_level_processor[dst_hidden_name](
curr_latent,
batch_size=batch_size,
shard_shapes=shard_shapes_hiddens[dst_hidden_name],
model_comm_group=model_comm_group,
)

# print(f'After level of {dst_hidden_name}: ', curr_latent.shape)

## Upscale
for i in range(self.num_hidden - 1, 0, -1):
src_hidden_name = self._graph_hidden_names[i]
dst_hidden_name = self._graph_hidden_names[i - 1]

# Process to next level
curr_latent = self._run_mapper(
self.upscale[src_hidden_name],
(curr_latent, x_encoded_latents[dst_hidden_name]),
batch_size=batch_size,
shard_shapes=(shard_shapes_hiddens[src_hidden_name], shard_shapes_hiddens[dst_hidden_name]),
model_comm_group=model_comm_group,
)

# print(f'After upscaling of {src_hidden_name}: ', curr_latent.shape)

# Add skip connections
curr_latent = curr_latent + x_skip[dst_hidden_name]

# Processing at same level
if self.level_process:
curr_latent = self.up_level_processor[dst_hidden_name](
curr_latent,
batch_size=batch_size,
shard_shapes=shard_shapes_hiddens[dst_hidden_name],
model_comm_group=model_comm_group,
)

# print(f'After level of {dst_hidden_name}: ', curr_latent.shape)

# Run decoder
x_out = self._run_mapper(
self.decoder,
(curr_latent, x_data_latent),
batch_size=batch_size,
shard_shapes=(shard_shapes_hiddens[self._graph_hidden_names[0]], shard_shapes_data),
model_comm_group=model_comm_group,
)

# print('After decoding: ', x_out.shape)

x_out = (
einops.rearrange(
x_out,
"(batch ensemble grid) vars -> batch ensemble grid vars",
batch=batch_size,
ensemble=ensemble_size,
)
.to(dtype=x.dtype)
.clone()
)

# residual connection (just for the prognostic variables)
x_out[..., self._internal_output_idx] += x[:, -1, :, :, self._internal_input_idx]

for bounding in self.boundings:
# bounding performed in the order specified in the config file
x_out = bounding(x_out)

return x_out
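A small standalone sketch of the tensor reshaping used at the entry and exit of forward, with illustrative sizes; it only exercises the einops patterns, not the model itself.

import einops
import torch

batch, time, ensemble, grid, n_in, n_out = 2, 2, 1, 100, 8, 5

# Input flattening: (batch, time, ensemble, grid, vars) -> (nodes, features)
x = torch.randn(batch, time, ensemble, grid, n_in)
x_flat = einops.rearrange(x, "batch time ensemble grid vars -> (batch ensemble grid) (time vars)")
print(x_flat.shape)  # torch.Size([200, 16])

# Output unflattening: (nodes, vars) -> (batch, ensemble, grid, vars)
y = torch.randn(batch * ensemble * grid, n_out)
y_out = einops.rearrange(
    y, "(batch ensemble grid) vars -> batch ensemble grid vars", batch=batch, ensemble=ensemble
)
print(y_out.shape)  # torch.Size([2, 1, 100, 5])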