From 210a5bc2ecf8620a9b7aad3e1da798d52d40ba4c Mon Sep 17 00:00:00 2001 From: yurujaja Date: Tue, 15 Oct 2024 16:29:33 +0200 Subject: [PATCH] minor fix for unet, and reg_upernet --- pangaea/decoders/upernet.py | 11 +++++++++-- pangaea/encoders/unet_encoder.py | 6 +++--- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/pangaea/decoders/upernet.py b/pangaea/decoders/upernet.py index 76ed41c..6540959 100644 --- a/pangaea/decoders/upernet.py +++ b/pangaea/decoders/upernet.py @@ -475,7 +475,12 @@ class RegUPerNet(Decoder): """ def __init__( - self, encoder: Encoder, finetune: bool, channels: int, pool_scales=(1, 2, 3, 6) + self, + encoder: Encoder, + finetune: bool, + channels: int, + pool_scales=(1, 2, 3, 6), + feature_multiplier: int = 1, ): super().__init__( encoder=encoder, @@ -644,14 +649,16 @@ def __init__( multi_temporal: bool | int, multi_temporal_strategy: str | None, pool_scales=(1, 2, 3, 6), + feature_multiplier: int = 1, ): super().__init__( encoder=encoder, finetune=finetune, channels=channels, pool_scales=pool_scales, + feature_multiplier=feature_multiplier, ) - + self.model_name = "Reg_MT_UPerNet" self.multi_temporal = multi_temporal self.multi_temporal_strategy = multi_temporal_strategy diff --git a/pangaea/encoders/unet_encoder.py b/pangaea/encoders/unet_encoder.py index 7c10e52..8281db6 100644 --- a/pangaea/encoders/unet_encoder.py +++ b/pangaea/encoders/unet_encoder.py @@ -35,14 +35,15 @@ def __init__( input_bands=input_bands, input_size=input_size, embed_dim=0, - output_dim=0, - output_layers=output_layers, + output_dim=output_dim, + output_layers=None, multi_temporal=False, # single time frame multi_temporal_output=False, pyramid_output=True, download_url=download_url, ) + # TODO: now only supports optical bands for single time frame self.in_channels = len(input_bands["optical"]) # number of optical bands self.topology = topology @@ -91,7 +92,6 @@ def __init__(self, topology: Sequence[int]): def forward(self, x1: torch.Tensor) -> list: inputs = [x1] - # Downward U: for layer in self.down_seq.values(): out = layer(inputs[-1]) inputs.append(out)