From 4a5e7de2478665960d8314fa020209049f981cc1 Mon Sep 17 00:00:00 2001 From: Ritu Yadav <40523539+RituYadav92@users.noreply.github.com> Date: Mon, 23 Sep 2024 00:03:20 +0200 Subject: [PATCH] LTAE fix in RegMTUPerNet --- segmentors/upernet.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/segmentors/upernet.py b/segmentors/upernet.py index f09e26c2..bc1b1c84 100644 --- a/segmentors/upernet.py +++ b/segmentors/upernet.py @@ -430,8 +430,8 @@ def __init__(self, args, cfg, encoder, pool_scales=(1, 2, 3, 6)): if self.encoder.model_name in ["satlas_pretrain"]: self.multi_temporal_strategy = None if self.multi_temporal_strategy == "ltae": - self.tmap = LTAE2d(positional_encoding=False, in_channels=self.encoder.embed_dim, - mlp=[self.encoder.embed_dim, self.encoder.embed_dim], d_model=self.encoder.embed_dim) + self.tmap = LTAE2d(positional_encoding=False, in_channels=cfg['in_channels'], + mlp=[cfg['in_channels'], cfg['in_channels']], d_model=cfg['in_channels']) elif self.multi_temporal_strategy == "linear": self.tmap = nn.Linear(self.multi_temporal, 1) else: @@ -447,15 +447,16 @@ def forward(self, img, output_shape=None): else: feats.append(self.encoder({k: v[:,:,i,:,:] for k, v in img.items()})) - feats = [list(i) for i in zip(*feats)] - feats = [torch.stack(feat_layers, dim = 2) for feat_layers in feats] + feats = [list(i) for i in zip(*feats)] + feats = [torch.stack(feat_layers, dim = 2) for feat_layers in feats] + else: if not self.finetune: with torch.no_grad(): feats = self.encoder(img) else: feats = self.encoder(img) - + if self.encoder.model_name not in ["satlas_pretrain", "SpectralGPT"]: for i in range(len(feats)): if self.multi_temporal_strategy == "ltae":