Merge pull request #382 from nasaharvest/stacked-channels
Presto - stack channel groups instead of averaging them
gabrieltseng authored Jan 30, 2024
2 parents 1983c42 + 756ecc6 commit bd5028c
Showing 3 changed files with 92 additions and 6 deletions.
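The change swaps the single mean over all encoder output tokens for a per-band-group mean whose results are concatenated into one vector, so the fine-tuning head sees one embedding per channel group rather than one pooled embedding. A minimal sketch of the two aggregations with made-up sizes (4 groups, 12 timesteps, embedding size 16 are illustrative assumptions; the real encoder also handles the single-timestep SRTM group and masked-out tokens, which this sketch ignores):

import torch

batch, groups, timesteps, dim = 2, 4, 12, 16
# toy encoder output: one token per (band group, timestep) pair, ordered group by group
tokens = torch.randn(batch, groups * timesteps, dim)

# previous behaviour (Aggregate.MEAN): one pooled embedding per sample
pooled = tokens.mean(dim=1)  # [batch, dim]

# new behaviour (Aggregate.BAND_GROUPS_MEAN): mean within each group, then concatenate
per_group = tokens.view(batch, groups, timesteps, dim).mean(dim=2)  # [batch, groups, dim]
stacked = per_group.reshape(batch, groups * dim)  # [batch, groups * dim]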
1 change: 1 addition & 0 deletions environment-dev.yml
@@ -16,3 +16,4 @@ dependencies:
- openmapflow[data]==0.2.4rc4
- dvc[gs]
- fsspec==2022.11.0 # https://github.com/iterative/dvc-azure/issues/34
- einops
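The new einops dependency provides the repeat helper, which already appears in the encoder's forward pass and is used by the new unit test to tile a 1-D token sequence into a batched embedding. A one-line illustration with arbitrary sizes:

import torch
from einops import repeat

x = repeat(torch.arange(3).float(), "t -> b t d", b=2, d=4)  # shape [2, 3, 4]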
46 changes: 40 additions & 6 deletions src/single_file_presto_v2.py
@@ -1,6 +1,7 @@
import math
from collections import OrderedDict
from copy import deepcopy
from enum import Enum
from typing import Optional, Tuple, Union, cast

import numpy as np
@@ -29,6 +30,9 @@
NUM_DYNAMIC_WORLD_CLASSES = 9


Aggregate = Enum("Aggregate", ["NONE", "MEAN", "BAND_GROUPS_MEAN"])


class Attention(nn.Module):
    # https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py
    fast_attn: Final[bool]
@@ -341,20 +345,38 @@ def mask_tokens(x, mask):

        return x, kept_indices, removed_indices

    def band_groups_mean(
        self, x: torch.Tensor, kept_indices: torch.Tensor, num_timesteps: int
    ) -> torch.Tensor:
        x = x[:, 1:, :]  # latlon token - leave in or keep out?
        batch_size, embedding_dim = x.shape[0], x.shape[-1]
        cur_idx, groups = 0, []
        for channel_group, _ in self.band_groups.items():
            increment = num_timesteps if channel_group != "SRTM" else 1
            min_idx = cur_idx
            max_idx = min_idx + increment
            mask = (kept_indices >= min_idx) & (kept_indices < max_idx)
            # we assume kept_elements is the same for all batches
            kept_elements = sum(mask[0, :])
            groups.append(x[mask.bool()].view(batch_size, kept_elements, embedding_dim).mean(dim=1))
            cur_idx = max_idx
        return torch.cat(groups, dim=1)

    def forward(
        self,
        x: torch.Tensor,
        dynamic_world: torch.Tensor,
        latlons: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        month: Union[torch.Tensor, int] = 0,
        eval_task: bool = True,
        aggregate: Aggregate = Aggregate.NONE,
    ):
        device = x.device

        if mask is None:
            mask = torch.zeros_like(x, device=x.device).float()

        num_timesteps = x.shape[1]
        months = month_to_tensor(month, x.shape[0], x.shape[1], device)
        month_embedding = self.month_embed(months)
        positional_embedding = repeat(
@@ -430,8 +452,10 @@ def forward(
            x = blk(x)

        # mask will be a boolean of shape [batch, total_num_tokens]
        if eval_task:
        if aggregate == Aggregate.MEAN:
            return self.norm(x.mean(dim=1))
        elif aggregate == Aggregate.BAND_GROUPS_MEAN:
            return self.norm(self.band_groups_mean(x, kept_indices, num_timesteps))
        return self.norm(x), kept_indices, removed_indices


@@ -631,7 +655,7 @@ def forward(self, x, kept_indices, removed_indices, month):


class PrestoFineTuningModel(nn.Module):
    def __init__(self, encoder, head):
    def __init__(self, encoder, head, aggregate: Aggregate):
        super().__init__()
        self.encoder: Encoder = deepcopy(encoder)
        # make sure the model is trainable, since we can call
@@ -642,6 +666,7 @@ def __init__(self, encoder, head):
        self.encoder.pos_embed.requires_grad_(False)
        self.encoder.month_embed.requires_grad_(False)
        self.head = head
        self.aggregate = aggregate

    def forward(
        self,
@@ -658,7 +683,7 @@ def forward(
                latlons=latlons,
                mask=mask,
                month=month,
                eval_task=True,
                aggregate=self.aggregate,
            )
        )

@@ -742,13 +767,22 @@ def construct_finetuning_model(
        self,
        num_outputs: int,
        regression: bool = False,
        aggregate: Aggregate = Aggregate.BAND_GROUPS_MEAN,
    ):
        if aggregate == Aggregate.MEAN:
            hidden_size = self.encoder.embedding_size
        elif aggregate == Aggregate.BAND_GROUPS_MEAN:
            hidden_size = self.encoder.embedding_size * len(BANDS_GROUPS_IDX)
        else:
            raise ValueError
        head = FinetuningHead(
            num_outputs=num_outputs,
            hidden_size=self.encoder.embedding_size,
            hidden_size=hidden_size,
            regression=regression,
        )
        model = PrestoFineTuningModel(self.encoder, head).to(self.encoder.pos_embed.device)
        model = PrestoFineTuningModel(self.encoder, head, aggregate).to(
            self.encoder.pos_embed.device
        )
        model.train()
        return model

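A usage sketch for the updated fine-tuning entry point: the aggregation mode is now passed through to the encoder, and with BAND_GROUPS_MEAN (the new default) the head's input width is embedding_size * len(BANDS_GROUPS_IDX) instead of embedding_size. The two-class task and the import path below are illustrative assumptions, not values from this diff:

from src.single_file_presto_v2 import Aggregate, Presto

presto = Presto.construct()
finetuning_model = presto.construct_finetuning_model(
    num_outputs=2,  # assumed binary classification task
    regression=False,
    aggregate=Aggregate.BAND_GROUPS_MEAN,
)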
51 changes: 51 additions & 0 deletions test/unittest_presto.py
@@ -0,0 +1,51 @@
import os
import sys
import unittest

import torch
from einops import repeat

module_path = os.path.abspath(os.path.join(".."))
if module_path not in sys.path:
    sys.path.append(module_path)

from src.single_file_presto_v2 import BANDS_GROUPS_IDX, Presto # noqa: E402


class PrestoTest(unittest.TestCase):
    def test_band_groups_mean(self):
        # hidden size = 1
        num_timesteps = 12
        x = torch.arange(-1, len(BANDS_GROUPS_IDX)).unsqueeze(-1).float()
        x = torch.stack((x, x))
        cur_index, kept_indices = 0, []
        for band, _ in BANDS_GROUPS_IDX.items():
            kept_indices.append(cur_index)
            if band == "SRTM":
                cur_index += 1
            else:
                cur_index += num_timesteps
        kept_indices_t = torch.tensor(kept_indices)
        kept_indices_t = torch.stack((kept_indices_t, kept_indices_t))
        model = Presto.construct()
        out = model.encoder.band_groups_mean(x, kept_indices_t, num_timesteps)
        expected_out = torch.arange(0, len(BANDS_GROUPS_IDX))
        expected_out = torch.stack((expected_out, expected_out))
        self.assertTrue(torch.equal(expected_out, out))

    def test_band_groups_mean_d_128(self):
        num_timesteps = 12
        x = torch.tensor([-1, 0, 0, 1, 2, 2, 2, 3, 4, 5, 6, 7, 8]).float()
        x = repeat(x, "t -> b t d", b=2, d=128)
        kept_indices = torch.tensor(
            [
                [0, 1, 12, 24, 25, 26, 36, 48, 60, 72, 84, 85],
                [0, 8, 12, 24, 25, 28, 36, 48, 60, 72, 84, 85],
            ]
        )
        model = Presto.construct()
        out = model.encoder.band_groups_mean(x, kept_indices, num_timesteps)
        expected_out = torch.arange(0, len(BANDS_GROUPS_IDX))
        expected_out = torch.repeat_interleave(expected_out, 128)
        expected_out = torch.stack((expected_out, expected_out))
        self.assertTrue(torch.equal(expected_out, out))
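The new test appends the parent of the working directory to sys.path before importing from src, so one way to run it (an assumption about intended usage, not documented in the diff) is from inside test/: cd test && python -m unittest unittest_presto.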
