From 3bdeda8f7691e376b1e68762e1b4e090f2c44f4e Mon Sep 17 00:00:00 2001
From: SebastianGer
Date: Wed, 18 Sep 2024 16:18:03 +0200
Subject: [PATCH] SpectralGPT for multi-temporal

* Sets the SpectralGPT encoder to mono-temporal if UperNetCD is used.
  Otherwise SpectralGPT expects bi-temporal input, but UperNetCD
  purposefully splits it into mono-temporal images to compute separate
  features and then the difference between those two feature vectors.

* Moves SpectralGPT compatibility handling to config-level preprocessing
  in utils/configs.py

* Reverts most of the changes made to ensure compatibility between
  SpectralGPT and change detection segmentors, in preparation for using
  a different strategy to handle this.

* MTUpernet semantic segmentation: treats SpectralGPT as a
  single-temporal model

* Sets num_frames=1 for SpectralGPT instead of potentially using
  multi-temporal inputs. This follows the original SpectralGPT
  implementation and makes it easier to solve a different problem that
  occurs when combining SpectralGPT with change detection segmentors.

* Removes an unnecessary repeat

---------

Co-authored-by: Yuru Jia <91590963+yurujaja@users.noreply.github.com>
---
 foundation_models/spectralgpt_encoder.py |  2 +-
 segmentors/upernet.py                    | 41 +++++++++++++++---
 2 files changed, 27 insertions(+), 16 deletions(-)

diff --git a/foundation_models/spectralgpt_encoder.py b/foundation_models/spectralgpt_encoder.py
index 50f3aec1..220bd625 100644
--- a/foundation_models/spectralgpt_encoder.py
+++ b/foundation_models/spectralgpt_encoder.py
@@ -35,7 +35,7 @@ def __init__(self,
         self.model_name = "SpectralGPT"
         self.output_layers = cfg['output_layers']
 
-        self.num_frames = cfg['multi_temporal'] if cfg['multi_temporal'] else 1
+        self.num_frames = 1
 
         self.patch_embed = PatchEmbed(
             img_size, patch_size, self.num_frames, embed_dim, in_chans, t_patch_size)
diff --git a/segmentors/upernet.py b/segmentors/upernet.py
index 425d8c98..fdfc26e6 100644
--- a/segmentors/upernet.py
+++ b/segmentors/upernet.py
@@ -202,25 +202,31 @@ def __init__(self, args, cfg, encoder, pool_scales=(1, 2, 3, 6)):
     def forward(self, img, output_shape=None):
         """Forward function for change detection."""
 
-        if self.encoder.model_name not in ["Prithvi", "satlas_pretrain", "SpectralGPT"]:
+        if self.encoder.model_name in ["Prithvi", "satlas_pretrain"]:
+            if not self.finetune:
+                with torch.no_grad():
+                    feats = self.encoder(img)
+            else:
+                feats = self.encoder(img)
+        else:
             feats = []
             for i in range(self.multi_temporal):
                 if not self.finetune:
                     with torch.no_grad():
-                        feats.append(self.encoder({k: v[:,:,i,:,:] for k, v in img.items()}))
+                        if self.encoder.model_name in ["SpectralGPT"]:
+                            feats.append(self.encoder({k: v[:,:,[i],:,:] for k, v in img.items()}))
+                        else:
+                            feats.append(self.encoder({k: v[:,:,i,:,:] for k, v in img.items()}))
                 else:
-                    feats.append(self.encoder({k: v[:,:,i,:,:] for k, v in img.items()}))
+                    if self.encoder.model_name in ["SpectralGPT"]:
+                        feats.append(self.encoder({k: v[:,:,[i],:,:] for k, v in img.items()}))
+                    else:
+                        feats.append(self.encoder({k: v[:,:,i,:,:] for k, v in img.items()}))
 
             feats = [list(i) for i in zip(*feats)]
-            feats = [torch.stack(feat_layers, dim = 2) for feat_layers in feats]
-        else:
-            if not self.finetune:
-                with torch.no_grad():
-                    feats = self.encoder(img)
-            else:
-                feats = self.encoder(img)
-
-        if self.encoder.model_name not in ["satlas_pretrain", "SpectralGPT"]:
+            feats = [torch.stack(feat_layers, dim = 2) for feat_layers in feats]
+
+        if self.encoder.model_name not in ["satlas_pretrain"]:
             for i in range(len(feats)):
                 if self.multi_temporal_strategy == "ltae":
                     feats[i] = self.tmap(feats[i])
@@ -256,8 +262,13 @@ def forward(self, img, output_shape=None):
         """Forward function for change detection."""
 
         if self.encoder.model_name != "Prithvi":
-            img1 = {k: v[:,:,0,:,:] for k, v in img.items()}
-            img2 = {k: v[:,:,1,:,:] for k, v in img.items()}
+            if self.encoder.model_name == "SpectralGPT":
+                # Retains the temporal dimension
+                img1 = {k: v[:,:,[0],:,:] for k, v in img.items()}
+                img2 = {k: v[:,:,[1],:,:] for k, v in img.items()}
+            else:
+                img1 = {k: v[:,:,0,:,:] for k, v in img.items()}
+                img2 = {k: v[:,:,1,:,:] for k, v in img.items()}
 
         if not self.finetune:
             with torch.no_grad():
@@ -413,4 +424,4 @@ def forward(self, inputs):
         ]
         for i in range(len(inputs)):
             outputs.append(ops[i](inputs[i]))
-        return tuple(outputs)
\ No newline at end of file
+        return tuple(outputs)
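
Note (illustrative only, not part of the patch): a minimal sketch of why the SpectralGPT branches index the temporal axis with [i], [0], and [1] instead of plain integers. Integer indexing drops the temporal axis, while list indexing keeps it as a singleton dimension, which matches the (B, C, T=1, H, W) input expected by the encoder now that num_frames=1. The tensor name and shapes below are assumptions chosen for the example.

import torch

# Hypothetical bi-temporal batch: (batch, channels, time, height, width).
v = torch.randn(4, 6, 2, 128, 128)

mono_dropped = v[:, :, 0, :, :]    # shape (4, 6, 128, 128): temporal axis removed
mono_kept = v[:, :, [0], :, :]     # shape (4, 6, 1, 128, 128): temporal axis kept as T=1

assert mono_dropped.shape == (4, 6, 128, 128)
assert mono_kept.shape == (4, 6, 1, 128, 128)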