diff --git a/foundation_models/spectralgpt_encoder.py b/foundation_models/spectralgpt_encoder.py index 50f3aec1..220bd625 100644 --- a/foundation_models/spectralgpt_encoder.py +++ b/foundation_models/spectralgpt_encoder.py @@ -35,7 +35,7 @@ def __init__(self, self.model_name = "SpectralGPT" self.output_layers = cfg['output_layers'] - self.num_frames = cfg['multi_temporal'] if cfg['multi_temporal'] else 1 + self.num_frames = 1 self.patch_embed = PatchEmbed( img_size, patch_size, self.num_frames, embed_dim, in_chans, t_patch_size) diff --git a/segmentors/upernet.py b/segmentors/upernet.py index 425d8c98..fdfc26e6 100644 --- a/segmentors/upernet.py +++ b/segmentors/upernet.py @@ -202,25 +202,31 @@ def __init__(self, args, cfg, encoder, pool_scales=(1, 2, 3, 6)): def forward(self, img, output_shape=None): """Forward function for change detection.""" - if self.encoder.model_name not in ["Prithvi", "satlas_pretrain", "SpectralGPT"]: + if self.encoder.model_name in ["Prithvi", "satlas_pretrain"]: + if not self.finetune: + with torch.no_grad(): + feats = self.encoder(img) + else: + feats = self.encoder(img) + else: feats = [] for i in range(self.multi_temporal): if not self.finetune: with torch.no_grad(): - feats.append(self.encoder({k: v[:,:,i,:,:] for k, v in img.items()})) + if self.encoder.model_name in ["SpectralGPT"]: + feats.append(self.encoder({k: v[:,:,[i],:,:] for k, v in img.items()})) + else: + feats.append(self.encoder({k: v[:,:,i,:,:] for k, v in img.items()})) else: - feats.append(self.encoder({k: v[:,:,i,:,:] for k, v in img.items()})) + if self.encoder.model_name in ["SpectralGPT"]: + feats.append(self.encoder({k: v[:,:,[i],:,:] for k, v in img.items()})) + else: + feats.append(self.encoder({k: v[:,:,i,:,:] for k, v in img.items()})) feats = [list(i) for i in zip(*feats)] - feats = [torch.stack(feat_layers, dim = 2) for feat_layers in feats] - else: - if not self.finetune: - with torch.no_grad(): - feats = self.encoder(img) - else: - feats = self.encoder(img) - - if self.encoder.model_name not in ["satlas_pretrain", "SpectralGPT"]: + feats = [torch.stack(feat_layers, dim = 2) for feat_layers in feats] + + if self.encoder.model_name not in ["satlas_pretrain"]: for i in range(len(feats)): if self.multi_temporal_strategy == "ltae": feats[i] = self.tmap(feats[i]) @@ -256,8 +262,13 @@ def forward(self, img, output_shape=None): """Forward function for change detection.""" if self.encoder.model_name != "Prithvi": - img1 = {k: v[:,:,0,:,:] for k, v in img.items()} - img2 = {k: v[:,:,1,:,:] for k, v in img.items()} + if self.encoder.model_name == "SpectralGPT": + # Retains the temporal dimension + img1 = {k: v[:,:,[0],:,:] for k, v in img.items()} + img2 = {k: v[:,:,[1],:,:] for k, v in img.items()} + else: + img1 = {k: v[:,:,0,:,:] for k, v in img.items()} + img2 = {k: v[:,:,1,:,:] for k, v in img.items()} if not self.finetune: with torch.no_grad(): @@ -413,4 +424,4 @@ def forward(self, inputs): ] for i in range(len(inputs)): outputs.append(ops[i](inputs[i])) - return tuple(outputs) \ No newline at end of file + return tuple(outputs)