Skip to content

Commit

Permalink
feat: use NamedNodesAttributes in AnemoiModelEncProcDec
Browse files Browse the repository at this point in the history
  • Loading branch information
JPXKQX committed Oct 18, 2024
1 parent 6e64c91 commit 5c83f5f
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 55 deletions.
17 changes: 13 additions & 4 deletions src/anemoi/models/layers/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,22 +48,31 @@ class NamedNodesAttributes(torch.nn.Module):

def __init__(self, num_trainable_params: int, graph_data: HeteroData) -> None:
    """Initialize NamedNodesAttributes.

    Parameters
    ----------
    num_trainable_params : int
        Number of trainable parameters attached to every node set.
    graph_data : HeteroData
        Heterogeneous graph providing the node sets, their coordinates (``.x``)
        and their sizes (``.num_nodes``).
    """
    super().__init__()

    self.num_trainable_params = num_trainable_params
    # Derives nodes_names, num_nodes, coord_dims and attr_ndims from the graph
    # (nodes_names is set there, so it is not duplicated here).
    self.register_fixed_attributes(graph_data)

    self.trainable_tensors = nn.ModuleDict()
    for nodes_name in self.nodes_names:
        self.register_coordinates(nodes_name, graph_data[nodes_name].x)
        # Tensor size is looked up from self.num_nodes, so only the name is passed.
        self.register_tensor(nodes_name)

def register_fixed_attributes(self, graph_data: HeteroData) -> None:
    """Register fixed (non-trainable) attributes derived from the graph.

    Populates, per node set: ``num_nodes`` (node count), ``coord_dims``
    (dimension of the sin/cos-encoded coordinates) and ``attr_ndims``
    (total attribute width: encoded coordinates plus trainable parameters).
    """
    self.nodes_names = list(graph_data.node_types)
    self.num_nodes = {nodes_name: graph_data[nodes_name].num_nodes for nodes_name in self.nodes_names}
    # Dict (not set) comprehensions: these mappings are indexed by node name
    # elsewhere (e.g. self.coord_dims[nodes_name] below and attr_ndims lookups
    # in the model). The sin/cos encoding doubles the raw coordinate dimension.
    self.coord_dims = {nodes_name: 2 * graph_data[nodes_name].x.shape[1] for nodes_name in self.nodes_names}
    self.attr_ndims = {
        nodes_name: self.coord_dims[nodes_name] + self.num_trainable_params for nodes_name in self.nodes_names
    }

def register_coordinates(self, name: str, node_coords: torch.Tensor) -> None:
    """Register the sin/cos-encoded node coordinates as a persistent buffer.

    Parameters
    ----------
    name : str
        Node-set name; the buffer is stored as ``latlons_{name}``.
    node_coords : torch.Tensor
        Raw node coordinates to encode.
    """
    encoded = torch.cat((node_coords.sin(), node_coords.cos()), dim=-1)
    self.register_buffer(f"latlons_{name}", encoded, persistent=True)

def register_tensor(self, name: str) -> None:
    """Register a trainable tensor for the named node set.

    The tensor size is taken from ``self.num_nodes`` (populated by
    ``register_fixed_attributes``), so only the node-set name is needed.
    """
    self.trainable_tensors[name] = TrainableTensor(self.num_nodes[name], self.num_trainable_params)

def forward(self, name: str, batch_size: int) -> Tensor:
"""Forward pass."""
Expand Down
65 changes: 14 additions & 51 deletions src/anemoi/models/models/encoder_processor_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from torch_geometric.data import HeteroData

from anemoi.models.distributed.shapes import get_shape_shards
from anemoi.models.layers.graph import TrainableTensor
from anemoi.models.layers.graph import NamedNodesAttributes

LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -55,42 +55,33 @@ def __init__(

self._calculate_shapes_and_indices(data_indices)
self._assert_matching_indices(data_indices)

self.multi_step = model_config.training.multistep_input

self._define_tensor_sizes(model_config)

# Create trainable tensors
self._create_trainable_attributes()

# Register lat/lon of nodes
self._register_latlon("data", self._graph_name_data)
self._register_latlon("hidden", self._graph_name_hidden)

self.data_indices = data_indices

self.multi_step = model_config.training.multistep_input
self.num_channels = model_config.model.num_channels

input_dim = self.multi_step * self.num_input_channels + self.latlons_data.shape[1] + self.trainable_data_size
self.node_attributes = NamedNodesAttributes(model_config.model.trainable_parameters.hidden, self._graph_data)

input_dim = self.multi_step * self.num_input_channels + self.node_attributes.attr_ndims[self._graph_name_data]

# Encoder data -> hidden
self.encoder = instantiate(
model_config.model.encoder,
in_channels_src=input_dim,
in_channels_dst=self.latlons_hidden.shape[1] + self.trainable_hidden_size,
in_channels_dst=self.node_attributes.attr_ndims[self._graph_name_hidden],
hidden_dim=self.num_channels,
sub_graph=self._graph_data[(self._graph_name_data, "to", self._graph_name_hidden)],
src_grid_size=self._data_grid_size,
dst_grid_size=self._hidden_grid_size,
src_grid_size=self.node_attributes.num_nodes[self._graph_name_data],
dst_grid_size=self.node_attributes.num_nodes[self._graph_name_hidden],
)

# Processor hidden -> hidden
self.processor = instantiate(
model_config.model.processor,
num_channels=self.num_channels,
sub_graph=self._graph_data[(self._graph_name_hidden, "to", self._graph_name_hidden)],
src_grid_size=self._hidden_grid_size,
dst_grid_size=self._hidden_grid_size,
src_grid_size=self.node_attributes.num_nodes[self._graph_name_hidden],
dst_grid_size=self.node_attributes.num_nodes[self._graph_name_hidden],
)

# Decoder hidden -> data
Expand All @@ -101,8 +92,8 @@ def __init__(
hidden_dim=self.num_channels,
out_channels_dst=self.num_output_channels,
sub_graph=self._graph_data[(self._graph_name_hidden, "to", self._graph_name_data)],
src_grid_size=self._hidden_grid_size,
dst_grid_size=self._data_grid_size,
src_grid_size=self.node_attributes.num_nodes[self._graph_name_hidden],
dst_grid_size=self.node_attributes.num_nodes[self._graph_name_data],
)

# Instantiation of model output bounding functions (e.g., to ensure outputs like TP are positive definite)
Expand Down Expand Up @@ -132,34 +123,6 @@ def _assert_matching_indices(self, data_indices: dict) -> None:
self._internal_output_idx,
), f"Internal model indices must match {self._internal_input_idx} != {self._internal_output_idx}"

def _define_tensor_sizes(self, config: DotDict) -> None:
    """Cache grid sizes and trainable-parameter sizes from the graph and config."""
    graph = self._graph_data
    self._data_grid_size = graph[self._graph_name_data].num_nodes
    self._hidden_grid_size = graph[self._graph_name_hidden].num_nodes

    trainable = config.model.trainable_parameters
    self.trainable_data_size = trainable.data
    self.trainable_hidden_size = trainable.hidden

def _register_latlon(self, name: str, nodes: str) -> None:
"""Register lat/lon buffers.
Parameters
----------
name : str
Name to store the lat-lon coordinates of the nodes.
nodes : str
Name of nodes to map
"""
coords = self._graph_data[nodes].x
sin_cos_coords = torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)
self.register_buffer(f"latlons_{name}", sin_cos_coords, persistent=True)

def _create_trainable_attributes(self) -> None:
    """Create the trainable node-attribute tensors for the data and hidden grids."""
    self.trainable_data = TrainableTensor(
        trainable_size=self.trainable_data_size, tensor_size=self._data_grid_size
    )
    self.trainable_hidden = TrainableTensor(
        trainable_size=self.trainable_hidden_size, tensor_size=self._hidden_grid_size
    )

def _run_mapper(
self,
mapper: nn.Module,
Expand Down Expand Up @@ -209,12 +172,12 @@ def forward(self, x: Tensor, model_comm_group: Optional[ProcessGroup] = None) ->
x_data_latent = torch.cat(
(
einops.rearrange(x, "batch time ensemble grid vars -> (batch ensemble grid) (time vars)"),
self.trainable_data(self.latlons_data, batch_size=batch_size),
self.node_attributes(self._graph_name_data, batch_size=batch_size),
),
dim=-1, # feature dimension
)

x_hidden_latent = self.trainable_hidden(self.latlons_hidden, batch_size=batch_size)
x_hidden_latent = self.node_attributes(self._graph_name_hidden, batch_size=batch_size)

# get shard shapes
shard_shapes_data = get_shape_shards(x_data_latent, 0, model_comm_group)
Expand Down

0 comments on commit 5c83f5f

Please sign in to comment.