diff --git a/pangaea/decoders/upernet.py b/pangaea/decoders/upernet.py
index 6540959..3a753e8 100644
--- a/pangaea/decoders/upernet.py
+++ b/pangaea/decoders/upernet.py
@@ -3,7 +3,7 @@
 import torch.nn.functional as F
 
 from pangaea.decoders.base import Decoder
-from pangaea.decoders.ltae import LTAE2d
+from pangaea.decoders.ltae import LTAE2d, LTAEChannelAdaptor
 from pangaea.encoders.base import Encoder
@@ -26,6 +26,7 @@ def __init__(
         channels: int,
         pool_scales=(1, 2, 3, 6),
         feature_multiplier: int = 1,
+        in_channels: list[int] | None = None,
     ):
         super().__init__(
             encoder=encoder,
@@ -45,14 +46,21 @@ def __init__(
         self.input_layers = self.encoder.output_layers
         self.input_layers_num = len(self.input_layers)
 
-        self.in_channels = [dim * feature_multiplier for dim in self.encoder.output_dim]
+        if in_channels is None:
+            self.in_channels = [
+                dim * feature_multiplier for dim in self.encoder.output_dim
+            ]
+        else:
+            self.in_channels = [dim * feature_multiplier for dim in in_channels]
 
         if self.encoder.pyramid_output:
             rescales = [1 for _ in range(self.input_layers_num)]
         else:
             scales = [4, 2, 1, 0.5]
-            rescales = [scales[int(i / self.input_layers_num * 4)] for i in range(self.input_layers_num)]
-
+            rescales = [
+                scales[int(i / self.input_layers_num * 4)]
+                for i in range(self.input_layers_num)
+            ]
 
         self.neck = Feature2Pyramid(
             embed_dim=self.in_channels,
@@ -249,6 +257,9 @@ def __init__(
         pool_scales: list[int] = [1, 2, 3, 6],
         feature_multiplier: int = 1,
     ) -> None:
+        decoder_in_channels = self.get_decoder_in_channels(
+            multi_temporal_strategy, encoder
+        )
         super().__init__(
             encoder=encoder,
             num_classes=num_classes,
@@ -256,6 +267,7 @@ def __init__(
             channels=channels,
             pool_scales=pool_scales,
             feature_multiplier=feature_multiplier,
+            in_channels=decoder_in_channels,
         )
 
         self.multi_temporal = multi_temporal
@@ -267,17 +279,37 @@ def __init__(
             self.tmap = None
         else:
             if self.multi_temporal_strategy == "ltae":
+                # if the encoder output channels vary we must use an adaptor before the LTAE
+                ltae_in_channels = max(encoder.output_dim)
+                if ltae_in_channels != min(encoder.output_dim):
+                    self.ltae_adaptor = LTAEChannelAdaptor(
+                        in_channels=encoder.output_dim,
+                        out_channels=decoder_in_channels,
+                    )
+                else:
+                    self.ltae_adaptor = lambda x: x
+
                 self.tmap = LTAE2d(
                     positional_encoding=False,
-                    in_channels=encoder.output_dim,
-                    mlp=[encoder.output_dim, encoder.output_dim],
-                    d_model=encoder.output_dim,
+                    in_channels=ltae_in_channels,
+                    mlp=[ltae_in_channels, ltae_in_channels],
+                    d_model=ltae_in_channels,
                 )
             elif self.multi_temporal_strategy == "linear":
                 self.tmap = nn.Linear(self.multi_temporal, 1)
             else:
                 self.tmap = None
 
+    def get_decoder_in_channels(
+        self, multi_temporal_strategy: str | None, encoder: Encoder
+    ) -> list[int]:
+        if multi_temporal_strategy == "ltae":
+            # if the encoder output channels vary we must use an adaptor before the LTAE
+            ltae_in_channels = max(encoder.output_dim)
+            if ltae_in_channels != min(encoder.output_dim):
+                return [ltae_in_channels for _ in encoder.output_dim]
+        return encoder.output_dim
+
     def forward(
         self, img: dict[str, torch.Tensor], output_shape: torch.Size | None = None
     ) -> torch.Tensor:
@@ -319,14 +351,15 @@ def forward(
                 )
             feats = [list(i) for i in zip(*feats)]
 
+            # obtain features per layer
             feats = [torch.stack(feat_layers, dim=2) for feat_layers in feats]
 
             if self.tmap is not None:
-                for i in range(len(feats)):
-                    if self.multi_temporal_strategy == "ltae":
-                        feats[i] = self.tmap(feats[i])
-                    elif self.multi_temporal_strategy == "linear":
-                        feats[i] = self.tmap(feats[i].permute(0, 1, 3, 4, 2)).squeeze(-1)
+                if self.multi_temporal_strategy == "ltae":
+                    feats = self.ltae_adaptor(feats)
+                    feats = [self.tmap(f) for f in feats]
+                elif self.multi_temporal_strategy == "linear":
+                    feats = [self.tmap(f.permute(0, 1, 3, 4, 2)).squeeze(-1) for f in feats]
 
         feat = self.neck(feats)
         feat = self._forward_feature(feat)
@@ -475,10 +508,10 @@ class RegUPerNet(Decoder):
     """
 
     def __init__(
-            self,
-            encoder: Encoder,
-            finetune: bool,
-            channels: int,
+        self,
+        encoder: Encoder,
+        finetune: bool,
+        channels: int,
         pool_scales=(1, 2, 3, 6),
         feature_multiplier: int = 1,
     ):
@@ -493,7 +526,6 @@ def __init__(
             for param in self.encoder.parameters():
                 param.requires_grad = False
 
-
         self.input_layers = self.encoder.output_layers
         self.input_layers_num = len(self.input_layers)
 
@@ -503,11 +535,12 @@ def __init__(
             rescales = [1 for _ in range(self.input_layers_num)]
         else:
             scales = [4, 2, 1, 0.5]
-            rescales = [scales[int(i / self.input_layers_num * 4)] for i in range(self.input_layers_num)]
+            rescales = [
+                scales[int(i / self.input_layers_num * 4)]
+                for i in range(self.input_layers_num)
+            ]
 
-        self.neck = Feature2Pyramid(
-            embed_dim=self.in_channels, rescales=rescales
-        )
+        self.neck = Feature2Pyramid(embed_dim=self.in_channels, rescales=rescales)
 
         self.align_corners = False
 
@@ -790,16 +823,26 @@ def __init__(
         for i, k in enumerate(self.rescales):
             if k == 4:
-                self.ops.append(nn.Sequential(
-                    nn.ConvTranspose2d(embed_dim[i], embed_dim[i], kernel_size=2, stride=2),
-                    nn.SyncBatchNorm(embed_dim[i]),
-                    nn.GELU(),
-                    nn.ConvTranspose2d(embed_dim[i], embed_dim[i], kernel_size=2, stride=2),
-                ))
+                self.ops.append(
+                    nn.Sequential(
+                        nn.ConvTranspose2d(
+                            embed_dim[i], embed_dim[i], kernel_size=2, stride=2
+                        ),
+                        nn.SyncBatchNorm(embed_dim[i]),
+                        nn.GELU(),
+                        nn.ConvTranspose2d(
+                            embed_dim[i], embed_dim[i], kernel_size=2, stride=2
+                        ),
+                    )
+                )
             elif k == 2:
-                self.ops.append(nn.Sequential(
-                    nn.ConvTranspose2d(embed_dim[i], embed_dim[i], kernel_size=2, stride=2)
-                ))
+                self.ops.append(
+                    nn.Sequential(
+                        nn.ConvTranspose2d(
+                            embed_dim[i], embed_dim[i], kernel_size=2, stride=2
+                        )
+                    )
+                )
             elif k == 1:
                 self.ops.append(nn.Identity())
             elif k == 0.5:
@@ -809,8 +852,6 @@ def __init__(
             else:
                 raise KeyError(f"invalid {k} for feature2pyramid")
 
-
-
     def forward(self, inputs):
        assert len(inputs) == len(self.rescales)
        outputs = []
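
Note: LTAEChannelAdaptor is imported from pangaea/decoders/ltae.py and its implementation is not part of this diff. As a rough illustration of the role it plays here, the sketch below (hypothetical class name and made-up channel widths, not the library's actual adaptor) projects each stacked feature map of shape (B, C_l, T, H, W) to the common channel width that get_decoder_in_channels returns for the "ltae" strategy, i.e. max(encoder.output_dim):

import torch
import torch.nn as nn


class ChannelAdaptorSketch(nn.Module):
    """Minimal sketch of a per-level channel adaptor (not the real LTAEChannelAdaptor)."""

    def __init__(self, in_channels: list[int], out_channels: list[int]):
        super().__init__()
        # one 1x1x1 projection per pyramid level; identity where the widths already match
        self.projs = nn.ModuleList(
            nn.Identity() if c_in == c_out else nn.Conv3d(c_in, c_out, kernel_size=1)
            for c_in, c_out in zip(in_channels, out_channels)
        )

    def forward(self, feats: list[torch.Tensor]) -> list[torch.Tensor]:
        # feats[l] has shape (B, C_l, T, H_l, W_l), as produced by torch.stack(..., dim=2)
        return [proj(f) for proj, f in zip(self.projs, feats)]


if __name__ == "__main__":
    in_channels = [192, 384, 768, 768]      # made-up per-layer encoder output widths
    out_channels = [max(in_channels)] * 4   # what get_decoder_in_channels picks for "ltae"
    adaptor = ChannelAdaptorSketch(in_channels, out_channels)
    feats = [torch.randn(2, c, 4, 16, 16) for c in in_channels]
    print([tuple(f.shape) for f in adaptor(feats)])  # every entry: (2, 768, 4, 16, 16)

With such a projection every pyramid level reaches LTAE2d with ltae_in_channels = max(encoder.output_dim) channels, which is also the value the subclass passes up to the base decoder as in_channels=decoder_in_channels.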