Skip to content

Commit

Permalink
SpectralGPT for multi-temporal
Browse files Browse the repository at this point in the history
* Sets the SpectralGPT encoder to mono-temporal if UperNetCD is used. Otherwise, SpectralGPT would expect bitemporal input, but UperNetCD purposefully splits it into mono-temporal images to compute separate features and then the difference between those two feature vectors.

* Moves SpectralGPT compatibility handling to config-level preprocessing in utils/configs.py

* Reverts most of the changes made to ensure compatibility between SpectralGPT and change detection segmentors, in preparation of using a different strategy to handle this.

* MTUpernet semantic segmentation: treat SpectralGPT as a single-temporal model

* Set num_frames=1 for SpectralGPT, instead of potentially using multitemporal inputs. This follows the original SpectralGPT implementation and makes it easier to solve a different problem that occurs when combining SpectralGPT with change detection segmentors.

* remove unnecessary repeat

---------

Co-authored-by: Yuru Jia <[email protected]>
  • Loading branch information
SebastianGer and yurujaja authored Sep 18, 2024
1 parent 5d07681 commit 3bdeda8
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 16 deletions.
2 changes: 1 addition & 1 deletion foundation_models/spectralgpt_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def __init__(self,

self.model_name = "SpectralGPT"
self.output_layers = cfg['output_layers']
self.num_frames = cfg['multi_temporal'] if cfg['multi_temporal'] else 1
self.num_frames = 1

self.patch_embed = PatchEmbed(
img_size, patch_size, self.num_frames, embed_dim, in_chans, t_patch_size)
Expand Down
41 changes: 26 additions & 15 deletions segmentors/upernet.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,25 +202,31 @@ def __init__(self, args, cfg, encoder, pool_scales=(1, 2, 3, 6)):
def forward(self, img, output_shape=None):
"""Forward function for change detection."""

if self.encoder.model_name not in ["Prithvi", "satlas_pretrain", "SpectralGPT"]:
if self.encoder.model_name in ["Prithvi", "satlas_pretrain"]:
if not self.finetune:
with torch.no_grad():
feats = self.encoder(img)
else:
feats = self.encoder(img)
else:
feats = []
for i in range(self.multi_temporal):
if not self.finetune:
with torch.no_grad():
feats.append(self.encoder({k: v[:,:,i,:,:] for k, v in img.items()}))
if self.encoder.model_name in ["SpectralGPT"]:
feats.append(self.encoder({k: v[:,:,[i],:,:] for k, v in img.items()}))
else:
feats.append(self.encoder({k: v[:,:,i,:,:] for k, v in img.items()}))
else:
feats.append(self.encoder({k: v[:,:,i,:,:] for k, v in img.items()}))
if self.encoder.model_name in ["SpectralGPT"]:
feats.append(self.encoder({k: v[:,:,[i],:,:] for k, v in img.items()}))
else:
feats.append(self.encoder({k: v[:,:,i,:,:] for k, v in img.items()}))

feats = [list(i) for i in zip(*feats)]
feats = [torch.stack(feat_layers, dim = 2) for feat_layers in feats]
else:
if not self.finetune:
with torch.no_grad():
feats = self.encoder(img)
else:
feats = self.encoder(img)

if self.encoder.model_name not in ["satlas_pretrain", "SpectralGPT"]:
feats = [torch.stack(feat_layers, dim = 2) for feat_layers in feats]

if self.encoder.model_name not in ["satlas_pretrain"]:
for i in range(len(feats)):
if self.multi_temporal_strategy == "ltae":
feats[i] = self.tmap(feats[i])
Expand Down Expand Up @@ -256,8 +262,13 @@ def forward(self, img, output_shape=None):
"""Forward function for change detection."""

if self.encoder.model_name != "Prithvi":
img1 = {k: v[:,:,0,:,:] for k, v in img.items()}
img2 = {k: v[:,:,1,:,:] for k, v in img.items()}
if self.encoder.model_name == "SpectralGPT":
# Retains the temporal dimension
img1 = {k: v[:,:,[0],:,:] for k, v in img.items()}
img2 = {k: v[:,:,[1],:,:] for k, v in img.items()}
else:
img1 = {k: v[:,:,0,:,:] for k, v in img.items()}
img2 = {k: v[:,:,1,:,:] for k, v in img.items()}

if not self.finetune:
with torch.no_grad():
Expand Down Expand Up @@ -413,4 +424,4 @@ def forward(self, inputs):
]
for i in range(len(inputs)):
outputs.append(ops[i](inputs[i]))
return tuple(outputs)
return tuple(outputs)

0 comments on commit 3bdeda8

Please sign in to comment.