Skip to content

Commit

Permalink
fix spectralgpt, sampling, and minor bugs, enable replicating dataset…
Browse files Browse the repository at this point in the history
… in case it is smaller than batch size (#102)

* fix spectralgpt and minor bugs

* sampling before preprocessing and add data replicate for dataset smaller than batch size (after sampling)

* minor fixe
  • Loading branch information
LeungTsang authored Oct 18, 2024
1 parent 7108101 commit 52fd21a
Show file tree
Hide file tree
Showing 6 changed files with 38 additions and 31 deletions.
5 changes: 4 additions & 1 deletion configs/encoder/ssl4eo_data2vec.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ patch_size: 16
num_heads: 6
depth: 12
mlp_ratio: 4
init_values: 0.1
use_abs_pos_emb: False
use_shared_rel_pos_bias: True

input_bands:
optical:
Expand All @@ -32,4 +35,4 @@ output_layers:
- 7
- 11

output_dim: 384
output_dim: 384
2 changes: 1 addition & 1 deletion configs/train.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ limited_label_train: 1
limited_label_val: 1
limited_label_strategy: stratified # Options: stratified, oversampled, random
stratification_bins: 3 # number of bins for stratified sampling, only for stratified

data_replicate: 1


defaults:
Expand Down
36 changes: 20 additions & 16 deletions pangaea/datasets/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,13 +122,29 @@ def download(self) -> None:
raise NotImplementedError


class GeoFMSubset(Subset):
"""Custom subset class that retains dataset attributes."""

def __init__(self, dataset, indices):
super().__init__(dataset, indices)

# Copy relevant attributes from the original dataset
self.__dict__.update(dataset.__dict__)

def filter_by_indices(self, indices):
"""Apply filtering by indices directly in this subset."""
return GeoFMSubset(self.dataset, indices)



class GeoFMDataset(Dataset):
"""Base class for all datasets."""

def __init__(
self,
dataset: RawGeoFMDataset,
dataset: RawGeoFMDataset | GeoFMSubset,
preprocessor: Preprocessor = None,
replicate: int = None,
):
"""Initializes the dataset.
Expand All @@ -139,6 +155,7 @@ def __init__(
self.__dict__.update(dataset.__dict__)
self.raw_dataset = dataset
self.preprocessor = preprocessor
self.replicate = replicate if replicate is not None else 1

def __len__(self) -> int:
"""Returns the length of the dataset.
Expand All @@ -147,7 +164,7 @@ def __len__(self) -> int:
int: length of the dataset
"""

return len(self.raw_dataset)
return len(self.raw_dataset) * self.replicate

def __getitem__(self, i: int) -> dict[str, torch.Tensor | dict[str, torch.Tensor]]:
"""Returns the i-th item of the dataset.
Expand All @@ -170,22 +187,9 @@ def __getitem__(self, i: int) -> dict[str, torch.Tensor | dict[str, torch.Tensor
"metadata": dict}.
"""

output = self.raw_dataset[i]
output = self.raw_dataset[i // self.replicate]
if self.preprocessor is not None:
output = self.preprocessor(output)

return output


class GeoFMSubset(Subset):
"""Custom subset class that retains dataset attributes."""

def __init__(self, dataset, indices):
super().__init__(dataset, indices)

# Copy relevant attributes from the original dataset
self.__dict__.update(dataset.__dict__)

def filter_by_indices(self, indices):
"""Apply filtering by indices directly in this subset."""
return GeoFMSubset(self.dataset, indices)
2 changes: 1 addition & 1 deletion pangaea/datasets/hlsburnscars.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def __init__(

self.split_mapping = {
"train": "training",
"val": "validation",
"val": "training",
"test": "validation",
}

Expand Down
6 changes: 2 additions & 4 deletions pangaea/encoders/spectralgpt_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,14 +210,12 @@ def forward(self, image: dict[str, torch.Tensor]) -> list[torch.Tensor]:
out = x
out = out.view(N, T, L, C).transpose(2, 3).flatten(1, 2)
out = (
out.permute(0, 2, 1)
.contiguous()
.view(
out.view(
x.shape[0],
-1,
self.input_size // self.patch_size,
self.input_size // self.patch_size,
)
).contiguous()
)
output.append(out)

Expand Down
18 changes: 10 additions & 8 deletions pangaea/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,32 +161,34 @@ def main(cfg: DictConfig) -> None:
# get datasets
raw_train_dataset: RawGeoFMDataset = instantiate(cfg.dataset, split="train")
raw_val_dataset: RawGeoFMDataset = instantiate(cfg.dataset, split="val")
train_dataset = GeoFMDataset(raw_train_dataset, train_preprocessor)
val_dataset = GeoFMDataset(raw_val_dataset, val_preprocessor)

logger.info("Built {} dataset.".format(cfg.dataset.dataset_name))

if 0 < cfg.limited_label_train < 1:
indices = get_subset_indices(
train_dataset,
raw_train_dataset,
task=task_name,
strategy=cfg.limited_label_strategy,
label_fraction=cfg.limited_label_train,
num_bins=cfg.stratification_bins,
logger=logger,
)
train_dataset = GeoFMSubset(train_dataset, indices)
raw_train_dataset = GeoFMSubset(raw_train_dataset, indices)

if 0 < cfg.limited_label_val < 1:
indices = get_subset_indices(
val_dataset,
raw_val_dataset,
task=task_name,
strategy=cfg.limited_label_strategy,
label_fraction=cfg.limited_label_val,
num_bins=cfg.stratification_bins,
logger=logger,
)
val_dataset = GeoFMSubset(val_dataset, indices)
raw_val_dataset = GeoFMSubset(raw_val_dataset, indices)


train_dataset = GeoFMDataset(raw_train_dataset, train_preprocessor, cfg.data_replicate)
val_dataset = GeoFMDataset(raw_val_dataset, val_preprocessor, cfg.data_replicate)

logger.info("Built {} dataset.".format(cfg.dataset.dataset_name))

logger.info(
f"Total number of train patches: {len(train_dataset)}\n"
Expand Down

0 comments on commit 52fd21a

Please sign in to comment.