def band_groups_mean(
    self, x: torch.Tensor, kept_indices: torch.Tensor, num_timesteps: int
) -> torch.Tensor:
    """Pool encoder tokens into one mean embedding per band group.

    Token indices are laid out group by group in the iteration order of
    ``self.band_groups``: every group occupies ``num_timesteps`` consecutive
    index values, except ``"SRTM"`` which occupies a single one.

    Args:
        x: Encoder output whose first token along dim 1 is dropped before
            pooling (the original code marks it as the latlon token). The
            remaining tokens must line up one-to-one with ``kept_indices``.
        kept_indices: ``[batch, num_kept]`` token index of every kept token.
        num_timesteps: Number of index slots reserved per non-SRTM group.

    Returns:
        ``[batch, num_groups * embedding_dim]`` tensor: the per-group means
        concatenated along the last dimension. A group with no kept token in
        a sample contributes a zero vector instead of NaN.

    Raises:
        ValueError: If ``kept_indices`` does not match the token layout of
            ``x`` once the first token has been removed.
    """
    tokens = x[:, 1:, :]
    if kept_indices.shape != tokens.shape[:2]:
        raise ValueError(
            f"kept_indices has shape {tuple(kept_indices.shape)} but x provides "
            f"{tuple(tokens.shape[:2])} tokens after dropping the first one"
        )

    group_means = []
    start = 0
    for channel_group in self.band_groups:
        width = 1 if channel_group == "SRTM" else num_timesteps
        stop = start + width
        in_group = (kept_indices >= start) & (kept_indices < stop)
        # Masked mean computed per sample: unlike boolean indexing followed by
        # a view, this does not require every sample to keep the same number
        # of tokens in the group, so tokens can never leak across samples.
        selected = tokens.masked_fill(~in_group.unsqueeze(-1), 0)
        # clamp avoids a 0/0 NaN when a sample kept no token of this group
        counts = in_group.sum(dim=1, keepdim=True).clamp(min=1).to(tokens.dtype)
        group_means.append(selected.sum(dim=1) / counts)
        start = stop
    return torch.cat(group_means, dim=1)
def construct_finetuning_model(
    self,
    num_outputs: int,
    regression: bool = False,
    aggregate: Aggregate = Aggregate.BAND_GROUPS_MEAN,
):
    """Build a fine-tuning model from this model's encoder and a fresh head.

    Args:
        num_outputs: Forwarded to ``FinetuningHead``.
        regression: Forwarded to ``FinetuningHead``.
        aggregate: How the encoder pools its tokens before the head.
            ``Aggregate.MEAN`` yields one embedding; ``BAND_GROUPS_MEAN``
            yields one embedding per band group, concatenated.

    Returns:
        A ``PrestoFineTuningModel`` on the same device as the encoder's
        ``pos_embed``, switched to training mode.

    Raises:
        ValueError: If ``aggregate`` is neither ``MEAN`` nor
            ``BAND_GROUPS_MEAN`` (the head needs a single pooled vector).
    """
    if aggregate == Aggregate.MEAN:
        hidden_size = self.encoder.embedding_size
    elif aggregate == Aggregate.BAND_GROUPS_MEAN:
        # The encoder concatenates one embedding per entry of its own
        # band_groups, so the head input width is derived from that mapping.
        hidden_size = self.encoder.embedding_size * len(self.encoder.band_groups)
    else:
        raise ValueError(
            f"Unsupported aggregate for fine-tuning: {aggregate!r}; "
            "expected Aggregate.MEAN or Aggregate.BAND_GROUPS_MEAN"
        )
    head = FinetuningHead(
        num_outputs=num_outputs,
        hidden_size=hidden_size,
        regression=regression,
    )
    model = PrestoFineTuningModel(self.encoder, head, aggregate).to(
        self.encoder.pos_embed.device
    )
    model.train()
    return model
def test_band_groups_mean_d_128(self):
    """band_groups_mean pools per group when several tokens are kept per group.

    Token values equal the index of the band group they belong to, so each
    group's mean must be that group index, repeated over the embedding dim.
    """
    num_timesteps = 12
    embedding_dim = 128
    # leading -1 stands in for the first token, which band_groups_mean drops
    x = torch.tensor([-1, 0, 0, 1, 2, 2, 2, 3, 4, 5, 6, 7, 8]).float()
    x = repeat(x, "t -> b t d", b=2, d=embedding_dim)
    # both samples keep the same number of tokens per group, at different slots
    kept_indices = torch.tensor(
        [
            [0, 1, 12, 24, 25, 26, 36, 48, 60, 72, 84, 85],
            [0, 8, 12, 24, 25, 28, 36, 48, 60, 72, 84, 85],
        ]
    )
    model = Presto.construct()
    out = model.encoder.band_groups_mean(x, kept_indices, num_timesteps)
    # Expected values are built as float so the comparison does not hinge on
    # how torch.equal treats an int64 tensor against a float output.
    expected_out = torch.arange(0, len(BANDS_GROUPS_IDX)).float()
    expected_out = torch.repeat_interleave(expected_out, embedding_dim)
    expected_out = torch.stack((expected_out, expected_out))
    self.assertEqual(expected_out.shape, out.shape)
    self.assertTrue(torch.equal(expected_out, out))