Stratified sampling (#83)
* add dataset subsampling based on the task type and strategy; support random, stratified, and oversampled strategies

* add GeoFMSubset class

* add val stratification and logging info

* update README

* update get_exp_info function in run.py

* add regression stratification

* Updated the "stratify_regression_dataset_indices" function to return a fraction of labels from each bin

Previous code: a fraction of labels was selected from the sorted values; for biomass specifically, this meant selecting the samples with the lowest biomass (see the sketch after this list).

* deep copy ckpt

* add docstring

* Added a comment to guide oversampling for biomass, and for regression in general
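
To make the binning fix above concrete, here is a minimal, self-contained sketch of selecting a fraction of regression labels from each bin rather than from a sorted cut. It is not the repository's "stratify_regression_dataset_indices" implementation (that code is not among the diffs rendered below); every name and default in it is illustrative only.

```python
import numpy as np

def stratified_regression_indices(targets, label_fraction=0.5, num_bins=3, seed=23):
    """Keep `label_fraction` of the sample indices from every target-value bin.

    Binning the continuous targets (e.g. biomass) and sampling inside each bin
    keeps the subset's target distribution close to the full dataset's, instead
    of keeping only the lowest-valued samples as a sorted cut would.
    """
    targets = np.asarray(targets, dtype=float)
    rng = np.random.default_rng(seed)
    edges = np.linspace(targets.min(), targets.max(), num_bins + 1)
    # Assign each sample to one of `num_bins` equal-width bins.
    bin_ids = np.clip(np.digitize(targets, edges[1:-1]), 0, num_bins - 1)

    selected = []
    for b in range(num_bins):
        idx = np.flatnonzero(bin_ids == b)
        if idx.size == 0:
            continue  # empty bin, nothing to draw
        n_keep = max(1, int(round(idx.size * label_fraction)))
        selected.extend(rng.choice(idx, size=n_keep, replace=False).tolist())
    return sorted(selected)

# Toy usage: 50% of a skewed biomass-like target, three bins.
toy_targets = np.concatenate([np.full(80, 5.0), np.full(15, 60.0), np.full(5, 180.0)])
subset_idx = stratified_regression_indices(toy_targets, label_fraction=0.5, num_bins=3)
print(len(subset_idx))  # roughly half of each non-empty bin, about 50 indices
```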

---------

Co-authored-by: yurujaja <[email protected]>
Co-authored-by: Ritu Yadav <[email protected]>
Co-authored-by: Valerio Marsocci <[email protected]>
4 people authored Oct 10, 2024
1 parent a04f8bc commit 5794b4c
Showing 7 changed files with 312 additions and 50 deletions.
10 changes: 6 additions & 4 deletions README.md
@@ -101,7 +101,7 @@ We provide several ways to install the dependencies.
## 🏋️ Training
To run experiments, please refer to `configs/train.yaml`. In it, in addition to some basic info about training (e.g. `finetune` for fine-tuning also the encoder, `limited_label` to train the model on a subset of labels, `num_workers`, `batch_size` and so on), there are 5 different basic configs:
To run experiments, please refer to `configs/train.yaml`. In it, in addition to some basic info about training (e.g. `finetune` for also fine-tuning the encoder, `limited_label_train` to train the model on a stratified subset of labels, `num_workers`, `batch_size`, and so on), there are 5 different basic configs:
- `dataset`: Information about the downstream datasets, such as image size, band_statistics, classes, etc.
- `decoder`: Downstream task decoder fine-tuning related parameters, like the type of architecture (e.g. UPerNet), which multi-temporal strategy to use, and other related hparams (e.g. number of channels)
- `encoder`: GFM encoder related parameters. `output_layers` specifies which layers are used by the UPerNet decoder.
@@ -136,7 +136,7 @@ torchrun --nnodes=1 --nproc_per_node=1 pangaea/run.py \
task=segmentation
```
If you want to overwrite some parameters (e.g. turn off wandb, and change the batch size and the path to the dataset):
If you want to overwrite some parameters (e.g. turn off wandb, change the batch size and the path to the dataset, and use a 50% stratified sampled subset for training):
```
torchrun --nnodes=1 --nproc_per_node=1 pangaea/run.py \
--config-name=train \
@@ -148,7 +148,9 @@ torchrun --nnodes=1 --nproc_per_node=1 pangaea/run.py \
task=segmentation \
dataset.root_path=/path/to/the/dataset/hlsburnscars \
batch_size=16 \
use_wandb=False
use_wandb=False \
limited_label_train=0.5 \
limited_label_strategy=stratified
```
#### Multi-Temporal Semantic Segmentation
@@ -263,7 +265,7 @@ torchrun --nnodes=1 --nproc_per_node=1 pangaea/run.py \
--config-name=train \
dataset=sen1floods11 \
encoder=unet_encoder \
decoder=unet \
decoder=seg_unet \
preprocessing=seg_default \
criterion=cross_entropy \
task=segmentation \
7 changes: 6 additions & 1 deletion configs/train.yaml
@@ -12,7 +12,12 @@ batch_size: 32
# EXPERIMENT
finetune: false
ckpt_dir: null
limited_label: 1
limited_label_train: 1
limited_label_val: 1
limited_label_strategy: stratified # Options: stratified, oversampled, random
stratification_bins: 3 # number of bins for stratified sampling; only used with the stratified strategy



defaults:
- task: ???
33 changes: 32 additions & 1 deletion pangaea/datasets/base.py
@@ -1,5 +1,5 @@
import torch
from torch.utils.data import Dataset
from torch.utils.data import Dataset, Subset
import os

class GeoFMDataset(Dataset):
@@ -115,3 +115,34 @@ def download(self) -> None:
NotImplementedError: raise if the method is not implemented
"""
raise NotImplementedError


class GeoFMSubset(Subset):
"""Custom subset class that retains dataset attributes."""

def __init__(self, dataset, indices):
super().__init__(dataset, indices)

# Copy relevant attributes from the original dataset
self.dataset_name = getattr(dataset, 'dataset_name', None)
self.root_path = getattr(dataset, 'root_path', None)
self.auto_download = getattr(dataset, 'auto_download', None)
self.download_url = getattr(dataset, 'download_url', None)
self.img_size = getattr(dataset, 'img_size', None)
self.multi_temporal = getattr(dataset, 'multi_temporal', None)
self.multi_modal = getattr(dataset, 'multi_modal', None)
self.ignore_index = getattr(dataset, 'ignore_index', None)
self.num_classes = getattr(dataset, 'num_classes', None)
self.classes = getattr(dataset, 'classes', None)
self.distribution = getattr(dataset, 'distribution', None)
self.bands = getattr(dataset, 'bands', None)
self.data_mean = getattr(dataset, 'data_mean', None)
self.data_std = getattr(dataset, 'data_std', None)
self.data_min = getattr(dataset, 'data_min', None)
self.data_max = getattr(dataset, 'data_max', None)
self.split = getattr(dataset, 'split', None)

def filter_by_indices(self, indices):
"""Apply filtering by indices directly in this subset."""
return GeoFMSubset(self.dataset, indices)
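
As a quick illustration of why the attribute copying above matters: dataset-level fields such as `num_classes` stay readable directly on the subset, which is what lets the trainer drop its earlier `Subset` special-casing. The toy snippet below assumes `GeoFMSubset` from `pangaea.datasets.base` is importable; the `_ToyDataset` class is purely hypothetical.

```python
from torch.utils.data import Dataset
from pangaea.datasets.base import GeoFMSubset

class _ToyDataset(Dataset):
    """Hypothetical stand-in carrying a few of the attributes GeoFMSubset copies."""
    def __init__(self):
        self.num_classes, self.ignore_index, self.split = 2, -1, "train"
        self.samples = list(range(10))
    def __len__(self):
        return len(self.samples)
    def __getitem__(self, i):
        return self.samples[i]

subset = GeoFMSubset(_ToyDataset(), indices=[0, 2, 4, 6])
print(subset.num_classes, subset.ignore_index, len(subset))  # -> 2 -1 4
smaller = subset.filter_by_indices([0, 2])  # indices refer to the original dataset
```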

29 changes: 4 additions & 25 deletions pangaea/datasets/hlsburnscars.py
@@ -2,25 +2,21 @@
import time
import torch
import numpy as np
# import rasterio
import tifffile as tiff
from typing import Sequence, Dict, Any, Union, Literal, Tuple
from typing import Sequence, Tuple
from sklearn.model_selection import train_test_split
from glob import glob

import torch
import torchvision.transforms.functional as TF
import torchvision.transforms as T

import pathlib
import urllib
import tarfile

# from utils.registry import DATASET_REGISTRY
from pangaea.datasets.utils import DownloadProgressBar
from pangaea.datasets.base import GeoFMDataset

# @DATASET_REGISTRY.register()

class HLSBurnScars(GeoFMDataset):
def __init__(
self,
@@ -143,7 +139,6 @@ def __len__(self):
return len(self.image_list)

def __getitem__(self, index):

image = tiff.imread(self.image_list[index])
image = image.astype(np.float32) # Convert to float32
image = torch.from_numpy(image).permute(2, 0, 1)
@@ -155,30 +150,16 @@ def __getitem__(self, index):
invalid_mask = image == 9999
image[invalid_mask] = 0


output = {
'image': {
'optical': image,
},
'target': target,
'metadata': {}
}

return output


@staticmethod
def get_stratified_train_val_split(all_files) -> Tuple[Sequence[int], Sequence[int]]:
return output

# Fixed stratified sample to split data into train/val.
# This keeps 90% of datapoints belonging to an individual event in the training set and puts the remaining 10% in the validation set.
disaster_names = list(
map(lambda path: pathlib.Path(path).name.split("_")[0], all_files))
train_idxs, val_idxs = train_test_split(np.arange(len(all_files)),
test_size=0.1,
random_state=23,
stratify=disaster_names)
return {"train": train_idxs, "val": val_idxs}

@staticmethod
def download(self, silent=False):
@@ -211,6 +192,4 @@ def download(self, silent=False):
tar.extractall(output_path)
print("done.")

os.remove(output_path / temp_file_name)


os.remove(output_path / temp_file_name)
8 changes: 3 additions & 5 deletions pangaea/engine/trainer.py
@@ -1,3 +1,4 @@
import copy
import logging
import operator
import os
@@ -77,10 +78,7 @@ def __init__(
self.training_metrics = {}
self.best_ckpt = None
self.best_metric_comp = operator.gt
if isinstance(self.train_loader.dataset, Subset):
self.num_classes = self.train_loader.dataset.dataset.num_classes
else:
self.num_classes = self.train_loader.dataset.num_classes
self.num_classes = self.train_loader.dataset.num_classes

assert precision in [
"fp32",
@@ -200,7 +198,7 @@ def get_checkpoint(self, epoch: int) -> dict[str, dict | int]:
"scaler": self.scaler.state_dict(),
"epoch": epoch,
}
return checkpoint
return copy.deepcopy(checkpoint)

def save_model(
self,
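
The `copy.deepcopy(checkpoint)` change above guards against an aliasing pitfall: `state_dict()` returns tensors that share storage with the live parameters, so a "best" checkpoint held in memory would silently keep drifting as training continues. A standalone illustration of the difference (not the trainer's code):

```python
import copy
import torch

model = torch.nn.Linear(4, 2)
ckpt_ref = {"model": model.state_dict()}                   # aliases the live weights
ckpt_copy = copy.deepcopy({"model": model.state_dict()})   # detached snapshot

with torch.no_grad():
    model.weight.add_(1.0)  # stand-in for further training updates

print(torch.equal(ckpt_ref["model"]["weight"], model.weight))   # True: drifted along
print(torch.equal(ckpt_copy["model"]["weight"], model.weight))  # False: preserved
```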
51 changes: 37 additions & 14 deletions pangaea/run.py
@@ -10,7 +10,7 @@
from hydra.core.hydra_config import HydraConfig
from hydra.utils import instantiate
from omegaconf import DictConfig, OmegaConf
from torch.utils.data import DataLoader, Dataset, Subset
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.distributed import DistributedSampler

from pangaea.decoders.base import Decoder
@@ -25,23 +25,34 @@
get_generator,
seed_worker,
)
from pangaea.utils.subset_sampler import get_subset_indices
from pangaea.datasets.base import GeoFMSubset


def get_exp_name(hydra_config: HydraConf) -> str:
def get_exp_info(hydra_config: HydraConf) -> dict:
"""Create a unique experiment name based on the choices made in the config.
Args:
hydra_config (HydraConf): hydra config.
Returns:
str: experiment name.
dict: experiment information.
"""
choices = OmegaConf.to_container(hydra_config.runtime.choices)
timestamp = time.strftime("%Y%m%d_%H%M%S", time.localtime())
fm = choices["encoder"]
decoder = choices["decoder"]
ds = choices["dataset"]
return f"{timestamp}-{fm}-{decoder}-{ds}"
task = choices["task"]
exp_info = {
"timestamp": timestamp,
"fm": fm,
"decoder": decoder,
"ds": ds,
"task": task,
"exp_name": f"{timestamp}_{fm}_{decoder}_{ds}",
}
return exp_info


@hydra.main(version_base=None, config_path="../configs", config_name="train")
@@ -64,7 +75,9 @@ def main(cfg: DictConfig) -> None:
# true if training else false
train_run = cfg.train
if train_run:
exp_name = get_exp_name(HydraConfig.get())
exp_info = get_exp_info(HydraConfig.get())
exp_name = exp_info["exp_name"]
task_name = exp_info["task"]
exp_dir = pathlib.Path(cfg.work_dir) / exp_name
exp_dir.mkdir(parents=True, exist_ok=True)
logger_path = exp_dir / "train.log"
@@ -138,6 +151,7 @@ def main(cfg: DictConfig) -> None:

# training
if train_run:

for preprocess in cfg.preprocessing.train:
train_dataset: Dataset = instantiate(
preprocess, dataset=train_dataset, encoder=encoder
@@ -146,17 +160,25 @@
val_dataset: Dataset = instantiate(
preprocess, dataset=val_dataset, encoder=encoder
)
if 0 < cfg.limited_label < 1:
n_train_samples = len(train_dataset)
indices = random.sample(
range(n_train_samples), int(n_train_samples * cfg.limited_label)

if 0 < cfg.limited_label_train < 1:
indices = get_subset_indices(
train_dataset, task=task_name, strategy=cfg.limited_label_strategy,
label_fraction=cfg.limited_label_train, num_bins=cfg.stratification_bins, logger=logger
)
train_dataset = GeoFMSubset(train_dataset, indices)

if 0 < cfg.limited_label_val < 1:
indices = get_subset_indices(
val_dataset, task=task_name, strategy=cfg.limited_label_strategy,
label_fraction=cfg.limited_label_val, num_bins=cfg.stratification_bins, logger=logger
)
train_dataset = Subset(train_dataset, indices)
logger.info(
f"Created a subset of the train dataset, with {cfg.limited_label * 100}% of the labels available"
val_dataset = GeoFMSubset(val_dataset, indices)

logger.info(
f"Total number of train patches: {len(train_dataset)}\n"
f"Total number of validation patches: {len(val_dataset)}\n"
)
else:
logger.info("The entire train dataset will be used.")

# get train val data loaders
train_loader = DataLoader(
Expand All @@ -172,6 +194,7 @@ def main(cfg: DictConfig) -> None:
drop_last=True,
collate_fn=collate_fn,
)

val_loader = DataLoader(
val_dataset,
sampler=DistributedSampler(val_dataset),
(Diff for the remaining changed file, presumably pangaea/utils/subset_sampler.py, which run.py imports above, is not rendered on this page.)
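
For orientation only, here is a hypothetical sketch of a `get_subset_indices` dispatcher with the call signature used in `run.py` above (task, strategy, label_fraction, num_bins, logger). It is inferred from the call sites, not taken from the actual subset_sampler.py; only the random strategy is spelled out, and the stratified and oversampled branches are described in comments.

```python
import random

def get_subset_indices(dataset, task="segmentation", strategy="random",
                       label_fraction=1.0, num_bins=3, logger=None):
    """Hypothetical dispatcher mirroring the keyword arguments used in run.py."""
    if logger is not None:
        logger.info(
            f"Subsampling {label_fraction:.0%} of the '{task}' data "
            f"with the '{strategy}' strategy."
        )
    n = len(dataset)
    if strategy == "random":
        return sorted(random.sample(range(n), int(n * label_fraction)))
    if strategy in ("stratified", "oversampled"):
        # A real implementation would inspect per-sample labels (class maps for
        # segmentation, continuous targets split into `num_bins` bins for
        # regression) and keep `label_fraction` of the indices from each
        # class/bin, repeating under-represented ones when oversampling.
        raise NotImplementedError("Illustrative sketch only.")
    raise ValueError(f"Unknown strategy: {strategy}")
```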
