110 Add multi-temporal UNet and ViT encoder (#114)
* Add multi-temporal UNet encoder and corresponding config

* Add multi-temporal ViT encoder
gle-bellier authored Oct 31, 2024
1 parent fcd7560 commit ecfefa1
Showing 4 changed files with 288 additions and 47 deletions.
15 changes: 15 additions & 0 deletions configs/encoder/unet_encoder_mi.yaml
@@ -0,0 +1,15 @@
_target_: pangaea.encoders.unet_encoder.UNetMT
encoder_weights: null
download_url: null
input_size: ${dataset.img_size}
multi_temporal: ${dataset.multi_temporal}
topology: [64, 128, 256, 512]

input_bands: ${dataset.bands}

output_dim:
- 64
- 128
- 256
- 512
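
For reference, a minimal sketch of how this config maps onto the UNetMT constructor, assuming the ${dataset.*} interpolations are resolved by Hydra/OmegaConf; the band list, image size, and frame count below are placeholder values, not part of the commit.

from pangaea.encoders.unet_encoder import UNetMT

# Placeholder values standing in for the ${dataset.*} interpolations.
encoder = UNetMT(
    input_bands={"optical": ["B4", "B3", "B2"]},  # ${dataset.bands}
    input_size=224,                               # ${dataset.img_size}
    multi_temporal=4,                             # ${dataset.multi_temporal}
    topology=[64, 128, 256, 512],
    output_dim=[64, 128, 256, 512],
    download_url=None,
    encoder_weights=None,  # trained from scratch
)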

26 changes: 26 additions & 0 deletions configs/encoder/vit_mi.yaml
@@ -0,0 +1,26 @@
_target_: pangaea.encoders.vit_encoder.VIT_EncoderMT
encoder_weights: ./pretrained_models/jx_vit_base_p16_224-80ecf9dd.pt
download_url: https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth

multi_temporal: ${dataset.multi_temporal}
embed_dim: 768
input_size: 224
patch_size: 16
depth: 12
num_heads: 12
mlp_ratio: 4

input_bands:
optical:
- B4
- B3
- B2


output_layers:
- 3
- 5
- 7
- 11

output_dim: 768
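
As a rough sanity check of the shapes this config implies (a sketch, not part of the commit): with input_size 224 and patch_size 16 the patch embedding yields a 14 x 14 token grid, and each index in output_layers selects one of the 12 transformer blocks whose patch tokens are reshaped into a spatial feature map.

# Hypothetical shape check for the values in this config.
input_size, patch_size, embed_dim = 224, 16, 768
grid = input_size // patch_size  # 14 tokens per side
num_patches = grid * grid        # 196 patch tokens (plus one cls token)
output_layers = [3, 5, 7, 11]    # blocks whose outputs are returned
# The forward pass returns len(output_layers) tensors of shape (batch, embed_dim, grid, grid).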
61 changes: 61 additions & 0 deletions pangaea/encoders/unet_encoder.py
@@ -60,6 +60,67 @@ def load_encoder_weights(self, logger: Logger) -> None:
pass


class UNetMT(Encoder):
"""
Multi-Temporal UNet Encoder for Supervised Baseline, to be trained from scratch.
It supports multi-temporal inputs with optical bands: the time frames are merged
into the channel dimension and reduced by a DoubleConv before the UNet encoder.

Args:
input_bands (dict[str, list[str]]): Band names, specifically expecting the 'optical' key with a list of bands.
input_size (int): Size of the input images (height and width).
multi_temporal (int): Number of time frames in the input.
topology (Sequence[int]): The number of feature channels at each stage of the U-Net encoder.
output_dim (int | list[int]): Dimension(s) of the returned feature maps.
download_url (str): URL to download pretrained weights from (unused, the encoder is trained from scratch).
encoder_weights (str | None): Path to pretrained weights (None, trained from scratch).
"""

def __init__(
self,
input_bands: dict[str, list[str]],
input_size: int,
multi_temporal: int,
topology: Sequence[int],
output_dim: int | list[int],
download_url: str,
encoder_weights: str | None = None,
):
super().__init__(
model_name="unet_encoder",
encoder_weights=encoder_weights, # no pre-trained weights, train from scratch
input_bands=input_bands,
input_size=input_size,
embed_dim=0,
output_dim=output_dim,
output_layers=None,
multi_temporal=multi_temporal,
multi_temporal_output=False,
pyramid_output=True,
download_url=download_url,
)

self.in_channels = len(input_bands["optical"]) # number of optical bands
self.topology = topology

self.in_conv = InConv(self.in_channels, self.topology[0], DoubleConv)
self.encoder = UNet_Encoder(self.topology)

self.time_merging = DoubleConv(
in_ch=self.in_channels * self.multi_temporal, out_ch=self.in_channels
)

def forward(self, image):
x = image["optical"]
b, c, t, h, w = x.shape
# merge time and channels dimension
x = x.reshape(b, c * t, h, w)
x = self.time_merging(x)

feat = self.in_conv(x)
output = self.encoder(feat)
return output

def load_encoder_weights(self, logger: Logger) -> None:
pass


class UNet_Encoder(nn.Module):
"""
UNet Encoder class that defines the architecture of the encoder part of the UNet.
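To make the expected input layout concrete, a hedged usage sketch of the new UNetMT forward pass; it is not part of the commit and assumes UNet_Encoder returns one feature map per topology stage, as pyramid_output=True suggests.

import torch
from pangaea.encoders.unet_encoder import UNetMT

encoder = UNetMT(
    input_bands={"optical": ["B4", "B3", "B2"]},
    input_size=224,
    multi_temporal=4,  # four time frames
    topology=[64, 128, 256, 512],
    output_dim=[64, 128, 256, 512],
    download_url=None,
)

# Multi-temporal batch: (batch, channels, time, height, width).
batch = {"optical": torch.randn(2, 3, 4, 224, 224)}
# The 3 bands x 4 frames are flattened to 12 channels, mapped back to 3 by the
# time_merging DoubleConv, then passed through in_conv and the UNet encoder.
features = encoder(batch)  # assumed: one feature map per stage in `topology`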
233 changes: 186 additions & 47 deletions pangaea/encoders/vit_encoder.py
@@ -14,63 +14,79 @@

import torch
import torch.nn as nn
from timm.models.vision_transformer import Block, PatchEmbed

from pangaea.encoders.unet_encoder import DoubleConv

import timm.models.vision_transformer
from timm.models.vision_transformer import PatchEmbed, Block
from .base import Encoder


class VIT_Encoder(Encoder):
""" Vision Transformer with support for global average pooling
"""

def __init__(self,
encoder_weights,
input_size,
input_bands,
embed_dim,
output_layers,
output_dim,
download_url,
patch_size=16,
depth=12,
num_heads=12,
mlp_ratio=4,
qkv_bias=True,
norm_layer=partial(nn.LayerNorm, eps=1e-6)):

Encoder.__init__(self,
model_name="vit_encoder",
encoder_weights=encoder_weights,
input_bands=input_bands,
input_size=input_size,
embed_dim=embed_dim,
output_layers=output_layers,
output_dim=output_dim,
multi_temporal=False,
multi_temporal_output=False,
pyramid_output=False,
download_url=download_url
)
"""Vision Transformer with support for global average pooling"""

def __init__(
self,
encoder_weights,
input_size,
input_bands,
embed_dim,
output_layers,
output_dim,
download_url,
patch_size=16,
depth=12,
num_heads=12,
mlp_ratio=4,
qkv_bias=True,
norm_layer=partial(nn.LayerNorm, eps=1e-6),
):
Encoder.__init__(
self,
model_name="vit_encoder",
encoder_weights=encoder_weights,
input_bands=input_bands,
input_size=input_size,
embed_dim=embed_dim,
output_layers=output_layers,
output_dim=output_dim,
multi_temporal=False,
multi_temporal_output=False,
pyramid_output=False,
download_url=download_url,
)

self.patch_size = patch_size
self.patch_embed = PatchEmbed(input_size, patch_size, in_chans=3, embed_dim=embed_dim)
self.patch_embed = PatchEmbed(
input_size, patch_size, in_chans=3, embed_dim=embed_dim
)
num_patches = self.patch_embed.num_patches

self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False) # fixed sin-cos embedding

self.blocks = nn.ModuleList([
Block(embed_dim, num_heads, mlp_ratio, qkv_bias=qkv_bias, norm_layer=norm_layer)
for i in range(depth)])
self.pos_embed = nn.Parameter(
torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False
) # fixed sin-cos embedding

self.blocks = nn.ModuleList(
[
Block(
embed_dim,
num_heads,
mlp_ratio,
qkv_bias=qkv_bias,
norm_layer=norm_layer,
)
for i in range(depth)
]
)
self.norm = norm_layer(embed_dim)


def forward(self, images):
x = images["optical"].squeeze(2)
x = self.patch_embed(x)

cls_tokens = self.cls_token.expand(x.shape[0], -1, -1) # stole cls_tokens impl from Phil Wang, thanks
cls_tokens = self.cls_token.expand(
x.shape[0], -1, -1
) # stole cls_tokens impl from Phil Wang, thanks
x = torch.cat((cls_tokens, x), dim=1)
x = x + self.pos_embed

@@ -82,16 +98,139 @@ def forward(self, images):

if i in self.output_layers:
out = x[:, 1:]
out = out.transpose(1, 2).view(
x.shape[0],
-1,
self.input_size // self.patch_size,
self.input_size // self.patch_size,
).contiguous()
out = (
out.transpose(1, 2)
.view(
x.shape[0],
-1,
self.input_size // self.patch_size,
self.input_size // self.patch_size,
)
.contiguous()
)
output.append(out)

return output

def load_encoder_weights(self, logger: Logger) -> None:
pretrained_model = torch.load(self.encoder_weights, map_location="cpu")
k = pretrained_model.keys()
pretrained_encoder = {}
incompatible_shape = {}
missing = {}
for name, param in self.named_parameters():
if name not in k:
missing[name] = param.shape
elif pretrained_model[name].shape != param.shape:
incompatible_shape[name] = (param.shape, pretrained_model[name].shape)
pretrained_model.pop(name)
else:
pretrained_encoder[name] = pretrained_model.pop(name)

self.load_state_dict(pretrained_encoder, strict=False)
self.parameters_warning(missing, incompatible_shape, logger)


class VIT_EncoderMT(Encoder):
"""Vision Transformer with support for global average pooling"""

def __init__(
self,
encoder_weights,
input_size,
input_bands,
embed_dim,
multi_temporal,
output_layers,
output_dim,
download_url,
patch_size=16,
depth=12,
num_heads=12,
mlp_ratio=4,
qkv_bias=True,
norm_layer=partial(nn.LayerNorm, eps=1e-6),
):
Encoder.__init__(
self,
model_name="vit_encoder",
encoder_weights=encoder_weights,
input_bands=input_bands,
input_size=input_size,
embed_dim=embed_dim,
output_layers=output_layers,
output_dim=output_dim,
multi_temporal=multi_temporal,
multi_temporal_output=False,
pyramid_output=False,
download_url=download_url,
)

self.patch_size = patch_size
self.in_channels = 3
self.patch_embed = PatchEmbed(
input_size, patch_size, in_chans=self.in_channels, embed_dim=embed_dim
)
num_patches = self.patch_embed.num_patches

self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
self.pos_embed = nn.Parameter(
torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False
) # fixed sin-cos embedding

self.blocks = nn.ModuleList(
[
Block(
embed_dim,
num_heads,
mlp_ratio,
qkv_bias=qkv_bias,
norm_layer=norm_layer,
)
for i in range(depth)
]
)
self.norm = norm_layer(embed_dim)

self.time_merging = DoubleConv(
in_ch=self.in_channels * self.multi_temporal, out_ch=self.in_channels
)

def forward(self, images):
x = images["optical"]
b, c, t, h, w = x.shape
# merge time and channels dimension
x = x.reshape(b, c * t, h, w)
x = self.time_merging(x)
x = self.patch_embed(x)

cls_tokens = self.cls_token.expand(
x.shape[0], -1, -1
) # stole cls_tokens impl from Phil Wang, thanks
x = torch.cat((cls_tokens, x), dim=1)
x = x + self.pos_embed

output = []
for i, blk in enumerate(self.blocks):
x = blk(x)
if i == len(self.blocks) - 1:
x = self.norm(x)

if i in self.output_layers:
out = x[:, 1:]
out = (
out.transpose(1, 2)
.view(
x.shape[0],
-1,
self.input_size // self.patch_size,
self.input_size // self.patch_size,
)
.contiguous()
)
output.append(out)

return output

def load_encoder_weights(self, logger: Logger) -> None:
pretrained_model = torch.load(self.encoder_weights, map_location="cpu")
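For completeness, a hedged sketch of how the multi-temporal ViT consumes a batch; it is not part of the commit and skips weight loading by passing encoder_weights=None.

import torch
from pangaea.encoders.vit_encoder import VIT_EncoderMT

encoder = VIT_EncoderMT(
    encoder_weights=None,  # load_encoder_weights is not called in this sketch
    input_size=224,
    input_bands={"optical": ["B4", "B3", "B2"]},
    embed_dim=768,
    multi_temporal=4,
    output_layers=[3, 5, 7, 11],
    output_dim=768,
    download_url=None,
)

# (batch, channels, time, height, width): 3 bands x 4 frames are merged back to
# 3 channels by time_merging before patch embedding.
batch = {"optical": torch.randn(2, 3, 4, 224, 224)}
features = encoder(batch)  # 4 tensors of shape (2, 768, 14, 14)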

