Skip to content

Commit

Permalink
LTAE fix in RegMTUPerNet
Browse files Browse the repository at this point in the history
  • Loading branch information
RituYadav92 authored Sep 22, 2024
1 parent 7461e2d commit 4a5e7de
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions segmentors/upernet.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,8 +430,8 @@ def __init__(self, args, cfg, encoder, pool_scales=(1, 2, 3, 6)):
if self.encoder.model_name in ["satlas_pretrain"]:
self.multi_temporal_strategy = None
if self.multi_temporal_strategy == "ltae":
self.tmap = LTAE2d(positional_encoding=False, in_channels=self.encoder.embed_dim,
mlp=[self.encoder.embed_dim, self.encoder.embed_dim], d_model=self.encoder.embed_dim)
self.tmap = LTAE2d(positional_encoding=False, in_channels=cfg['in_channels'],
mlp=[cfg['in_channels'], cfg['in_channels']], d_model=cfg['in_channels'])
elif self.multi_temporal_strategy == "linear":
self.tmap = nn.Linear(self.multi_temporal, 1)
else:
Expand All @@ -447,15 +447,16 @@ def forward(self, img, output_shape=None):
else:
feats.append(self.encoder({k: v[:,:,i,:,:] for k, v in img.items()}))

feats = [list(i) for i in zip(*feats)]
feats = [torch.stack(feat_layers, dim = 2) for feat_layers in feats]
feats = [list(i) for i in zip(*feats)]
feats = [torch.stack(feat_layers, dim = 2) for feat_layers in feats]

else:
if not self.finetune:
with torch.no_grad():
feats = self.encoder(img)
else:
feats = self.encoder(img)

if self.encoder.model_name not in ["satlas_pretrain", "SpectralGPT"]:
for i in range(len(feats)):
if self.multi_temporal_strategy == "ltae":
Expand Down

0 comments on commit 4a5e7de

Please sign in to comment.