Stratified sampling #83

Merged
merged 22 commits into from
Oct 10, 2024
Changes from 14 commits
Commits
73ca7ab
add stratified sampling to training set
alishibli97 Sep 24, 2024
00ce22e
add function in hls class
alishibli97 Sep 26, 2024
cd3ce0c
adding stratification
alishibli97 Oct 2, 2024
90e8c3c
add geofmsubset class
alishibli97 Oct 2, 2024
b3a5a1a
add val stratification and logging info
alishibli97 Oct 2, 2024
ed59f03
Merge remote-tracking branch 'origin/main' into stratified_sampling
alishibli97 Oct 2, 2024
f709e35
update readme
alishibli97 Oct 2, 2024
8fbc77e
Merge remote-tracking branch 'origin/main' into stratified_sampling
yurujaja Oct 7, 2024
f396083
re-add hlsburn train-val split
yurujaja Oct 7, 2024
8080464
limited label for both train and val, random or stratified sampling
yurujaja Oct 7, 2024
f25a445
add regression stratification
alishibli97 Oct 8, 2024
33053c2
Updated "stratify_regression_dataset_indices" function to return frac…
RituYadav92 Oct 9, 2024
927eec6
adding segmentation stratification
alishibli97 Oct 10, 2024
04f783b
deep copy ckpt
yurujaja Oct 10, 2024
8a4e365
Merge remote-tracking branch 'origin' into stratified_sampling
yurujaja Oct 10, 2024
b3b1378
synched the steps between classification and regression
RituYadav92 Oct 10, 2024
bdc7dd6
enable stratified sampling and oversampling
yurujaja Oct 10, 2024
32429d7
fix conflict
yurujaja Oct 10, 2024
28238db
add docstring
yurujaja Oct 10, 2024
afa76a0
Update README.md
VMarsocci Oct 10, 2024
9159a83
Added comment to guide oversampling for biomass or regression in general
RituYadav92 Oct 10, 2024
424535c
Update a comment
yurujaja Oct 10, 2024
8 changes: 5 additions & 3 deletions README.md
@@ -101,7 +101,7 @@ We provide several ways to install the dependencies.

## 🏋️ Training

To run experiments, please refer to `configs/train.yaml`. In it, in addition to some basic info about training (e.g. `finetune` to also fine-tune the encoder, `limited_label` to train the model on a subset of labels, `num_workers`, `batch_size` and so on), there are 5 different basic configs:
To run experiments, please refer to `configs/train.yaml`. In it, in addition to some basic info about training (e.g. `finetune` to also fine-tune the encoder, `limited_label_train` to train the model on a stratified subset of labels, `num_workers`, `batch_size` and so on), there are 5 different basic configs:
- `dataset`: Information about downstream datasets, such as image size, band statistics, classes, etc.
- `decoder`: Parameters for fine-tuning the downstream task decoder, such as the type of architecture (e.g. UPerNet), which multi-temporal strategy to use, and other related hyperparameters (e.g. number of channels)
- `encoder`: GFM encoder related parameters. `output_layers` specifies which encoder layers are passed to the UPerNet decoder.
@@ -136,7 +136,7 @@ torchrun --nnodes=1 --nproc_per_node=1 pangaea/run.py \
task=segmentation
```

If you want to override some parameters (e.g. turn off wandb, and change the batch size and the path to the dataset):
If you want to override some parameters (e.g. turn off wandb, change the batch size and the path to the dataset, and use a 50% stratified subset for training):
```
torchrun --nnodes=1 --nproc_per_node=1 pangaea/run.py \
--config-name=train \
@@ -148,7 +148,9 @@ torchrun --nnodes=1 --nproc_per_node=1 pangaea/run.py \
task=segmentation \
dataset.root_path=/path/to/the/dataset/hlsburnscars \
batch_size=16 \
use_wandb=False
use_wandb=False \
limited_label_train=0.5 \
limited_label_strategy=stratified_classification
```

#### Multi-Temporal Semantic Segmentation
6 changes: 5 additions & 1 deletion configs/train.yaml
@@ -12,7 +12,11 @@ batch_size: 32
# EXPERIMENT
finetune: false
ckpt_dir: null
limited_label: 1
limited_label_train: 1
limited_label_val: 1
limited_label_strategy: stratified_classification # Options: stratified_classification, stratified_regression, random
stratification_bins: 3 # number of bins for stratified sampling; only used by the stratified strategies


defaults:
- task: ???
33 changes: 32 additions & 1 deletion pangaea/datasets/base.py
@@ -1,5 +1,5 @@
import torch
from torch.utils.data import Dataset
from torch.utils.data import Dataset, Subset
import os

class GeoFMDataset(Dataset):
@@ -115,3 +115,34 @@ def download(self) -> None:
NotImplementedError: raise if the method is not implemented
"""
raise NotImplementedError


class GeoFMSubset(Subset):
"""Custom subset class that retains dataset attributes."""

def __init__(self, dataset, indices):
super().__init__(dataset, indices)

# Copy relevant attributes from the original dataset
self.dataset_name = getattr(dataset, 'dataset_name', None)
self.root_path = getattr(dataset, 'root_path', None)
self.auto_download = getattr(dataset, 'auto_download', None)
self.download_url = getattr(dataset, 'download_url', None)
self.img_size = getattr(dataset, 'img_size', None)
self.multi_temporal = getattr(dataset, 'multi_temporal', None)
self.multi_modal = getattr(dataset, 'multi_modal', None)
self.ignore_index = getattr(dataset, 'ignore_index', None)
self.num_classes = getattr(dataset, 'num_classes', None)
self.classes = getattr(dataset, 'classes', None)
self.distribution = getattr(dataset, 'distribution', None)
self.bands = getattr(dataset, 'bands', None)
self.data_mean = getattr(dataset, 'data_mean', None)
self.data_std = getattr(dataset, 'data_std', None)
self.data_min = getattr(dataset, 'data_min', None)
self.data_max = getattr(dataset, 'data_max', None)
self.split = getattr(dataset, 'split', None)

def filter_by_indices(self, indices):
"""Apply filtering by indices directly in this subset."""
return GeoFMSubset(self.dataset, indices)
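
For context, here is a minimal usage sketch of `GeoFMSubset`; the `full_dataset` variable is a stand-in for any instantiated `GeoFMDataset` (e.g. `HLSBurnScars`) and is not part of this diff:

```python
from pangaea.datasets.base import GeoFMSubset

def keep_every_other_sample(full_dataset):
    # Standard torch Subset semantics: subset[i] == full_dataset[indices[i]],
    # while dataset metadata (num_classes, data_mean, split, ...) stays accessible.
    indices = list(range(0, len(full_dataset), 2))
    subset = GeoFMSubset(full_dataset, indices)
    assert subset.num_classes == getattr(full_dataset, "num_classes", None)
    # filter_by_indices builds a new subset indexed against the *original* dataset
    return subset.filter_by_indices(indices[: len(indices) // 2])
```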

29 changes: 4 additions & 25 deletions pangaea/datasets/hlsburnscars.py
@@ -2,25 +2,21 @@
import time
import torch
import numpy as np
# import rasterio
import tifffile as tiff
from typing import Sequence, Dict, Any, Union, Literal, Tuple
from typing import Sequence, Tuple
from sklearn.model_selection import train_test_split
from glob import glob

import torch
import torchvision.transforms.functional as TF
import torchvision.transforms as T

import pathlib
import urllib
import tarfile

# from utils.registry import DATASET_REGISTRY
from pangaea.datasets.utils import DownloadProgressBar
from pangaea.datasets.base import GeoFMDataset

# @DATASET_REGISTRY.register()

class HLSBurnScars(GeoFMDataset):
def __init__(
self,
@@ -143,7 +139,6 @@ def __len__(self):
return len(self.image_list)

def __getitem__(self, index):

image = tiff.imread(self.image_list[index])
image = image.astype(np.float32) # Convert to float32
image = torch.from_numpy(image).permute(2, 0, 1)
@@ -155,30 +150,16 @@ def __getitem__(self, index):
invalid_mask = image == 9999
image[invalid_mask] = 0


output = {
'image': {
'optical': image,
},
'target': target,
'metadata': {}
}

return output


@staticmethod
def get_stratified_train_val_split(all_files) -> Tuple[Sequence[int], Sequence[int]]:
return output

# Fixed stratified sample to split data into train/val.
# This keeps 90% of datapoints belonging to an individual event in the training set and puts the remaining 10% in the validation set.
disaster_names = list(
map(lambda path: pathlib.Path(path).name.split("_")[0], all_files))
train_idxs, val_idxs = train_test_split(np.arange(len(all_files)),
test_size=0.1,
random_state=23,
stratify=disaster_names)
return {"train": train_idxs, "val": val_idxs}

@staticmethod
def download(self, silent=False):
@@ -211,6 +192,4 @@ def download(self, silent=False):
tar.extractall(output_path)
print("done.")

os.remove(output_path / temp_file_name)


os.remove(output_path / temp_file_name)
32 changes: 22 additions & 10 deletions pangaea/run.py
@@ -10,7 +10,7 @@
from hydra.core.hydra_config import HydraConfig
from hydra.utils import instantiate
from omegaconf import DictConfig, OmegaConf
from torch.utils.data import DataLoader, Dataset, Subset
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.distributed import DistributedSampler

from pangaea.decoders.base import Decoder
@@ -25,6 +25,8 @@
get_generator,
seed_worker,
)
from pangaea.utils.subset_sampler import get_subset_indices
from pangaea.datasets.base import GeoFMSubset


def get_exp_name(hydra_config: HydraConf) -> str:
@@ -138,6 +140,7 @@ def main(cfg: DictConfig) -> None:

# training
if train_run:

for preprocess in cfg.preprocessing.train:
train_dataset: Dataset = instantiate(
preprocess, dataset=train_dataset, encoder=encoder
@@ -146,17 +149,25 @@ def main(cfg: DictConfig) -> None:
val_dataset: Dataset = instantiate(
preprocess, dataset=val_dataset, encoder=encoder
)
if 0 < cfg.limited_label < 1:
n_train_samples = len(train_dataset)
indices = random.sample(
range(n_train_samples), int(n_train_samples * cfg.limited_label)

if 0 < cfg.limited_label_train < 1:
indices = get_subset_indices(
train_dataset, strategy=cfg.limited_label_strategy,
label_fraction=cfg.limited_label_train, num_bins=cfg.stratification_bins, logger=logger
)
train_dataset = GeoFMSubset(train_dataset, indices)

if 0 < cfg.limited_label_val < 1:
indices = get_subset_indices(
val_dataset, strategy=cfg.limited_label_strategy,
label_fraction=cfg.limited_label_val, num_bins=cfg.stratification_bins, logger=logger
)
train_dataset = Subset(train_dataset, indices)
logger.info(
f"Created a subset of the train dataset, with {cfg.limited_label * 100}% of the labels available"
val_dataset = GeoFMSubset(val_dataset, indices)

logger.info(
f"Total number of train patches: {len(train_dataset)}\n"
f"Total number of validation patches: {len(val_dataset)}\n"
)
else:
logger.info("The entire train dataset will be used.")

# get train val data loaders
train_loader = DataLoader(
Expand All @@ -172,6 +183,7 @@ def main(cfg: DictConfig) -> None:
drop_last=True,
collate_fn=collate_fn,
)

val_loader = DataLoader(
val_dataset,
sampler=DistributedSampler(val_dataset),
133 changes: 133 additions & 0 deletions pangaea/utils/subset_sampler.py
@@ -0,0 +1,133 @@
import random
from tqdm import tqdm
import numpy as np

# Function to calculate class distributions for classification with a progress bar
def calculate_class_distributions(dataset, num_classes):
class_distributions = []

# Adding a progress bar for dataset processing
for idx in tqdm(range(len(dataset)), desc="Calculating class distributions per sample"):
target = dataset[idx]['target']
total_pixels = target.numel()
class_counts = [(target == i).sum().item() for i in range(num_classes)]
class_ratios = [count / total_pixels for count in class_counts]
class_distributions.append(class_ratios)

return np.array(class_distributions)


# Function to calculate distribution metrics for regression
def calculate_regression_distributions(dataset):
distributions = []

# Adding a progress bar for dataset processing
for idx in tqdm(range(len(dataset)), desc="Calculating regression distributions per sample"):
target = dataset[idx]['target']
mean_value = target.mean().item() # Example for patch-wise mean; adjust as needed for other metrics
distributions.append(mean_value)

return np.array(distributions)


# Function to bin class distributions into a fixed number of categories
def bin_class_distributions(class_distributions, num_bins=3, logger=None):
logger.info(f"Class distributions are being binned into {num_bins} categories")
# Digitize each per-class ratio into one of num_bins equal-width bins over [0, 1]
binned_distributions = np.digitize(class_distributions, np.linspace(0, 1, num_bins+1)) - 1
return binned_distributions


# Function to bin regression distributions into a fixed number of categories
def bin_regression_distributions(regression_distributions, num_bins=3, logger=None):
logger.info(f"Regression distributions are being binned into {num_bins} categories")
# Define the range for binning based on minimum and maximum values in regression distributions
binned_distributions = np.digitize(
regression_distributions,
np.linspace(regression_distributions.min(), regression_distributions.max(), num_bins + 1)
) - 1
return binned_distributions


# Perform stratified sampling for classification and return only the selected/remaining indices, with even selection across bins
def stratify_classification_dataset_indices(dataset, num_classes, label_fraction=1.0, num_bins=3, logger=None):
# Step 1: Calculate class distributions with progress tracking
class_distributions = calculate_class_distributions(dataset, num_classes)

# Step 2: Bin the class distributions
binned_distributions = bin_class_distributions(class_distributions, num_bins=num_bins, logger=logger)

# Step 3: Prep a dictionary to hold indices for each bin combination
indices_per_bin = {}

# Combine the bins for each class to create unique bin identifiers
combined_bins = np.apply_along_axis(lambda row: ''.join(map(str, row)), axis=1, arr=binned_distributions)

# Step 4: Populate the dictionary with indices based on combined bin identifiers
for idx, bin_id in enumerate(combined_bins):
if bin_id not in indices_per_bin:
indices_per_bin[bin_id] = []
indices_per_bin[bin_id].append(idx)

# Step 5: Select a fraction of indices from each bin
selected_idx = []
for bin_id, indices in indices_per_bin.items():
num_to_select = int(max(1, len(indices) * label_fraction)) # Ensure at least one index is selected
selected_idx.extend(np.random.choice(indices, num_to_select, replace=False))

# Step 6: List the remaining unselected indices
other_idx = list(set(range(len(dataset))) - set(selected_idx))

return selected_idx, other_idx


# Function to perform stratification for regression and return only the indices
def stratify_regression_dataset_indices(dataset, label_fraction=1.0, num_bins=3, logger=None):
# Step 1: Calculate regression distributions with progress tracking
regression_distributions = calculate_regression_distributions(dataset)

# Step 2: Bin the regression distributions
binned_distributions = bin_regression_distributions(regression_distributions, num_bins=num_bins, logger=logger)

# Step 3: Prep a dictionary to hold indices for each bin
indices_per_bin = {i: [] for i in range(num_bins)}

# Step 4: Populate the indices per bin
for index, bin_index in enumerate(binned_distributions):
if bin_index in indices_per_bin:
indices_per_bin[bin_index].append(index)

# Step 5: Select a fraction of indices from each bin
selected_idx = []
for bin_index, indices in indices_per_bin.items():
num_to_select = int(max(1, len(indices) * label_fraction)) # Ensure at least one index is selected
selected_idx.extend(np.random.choice(indices, num_to_select, replace=False))

# Step 6: List the remaining unselected indices
other_idx = list(set(range(len(dataset))) - set(selected_idx))

return selected_idx, other_idx


# Function to get subset indices based on the strategy, supporting both classification and regression
def get_subset_indices(dataset, strategy="random", label_fraction=0.5, num_bins=3, logger=None):
logger.info(
f"Creating a subset of the {dataset.split} dataset using {strategy} strategy, with {label_fraction * 100}% of labels utilized."
)
if strategy == "stratified_classification":
indices, _ = stratify_classification_dataset_indices(
dataset, num_classes=dataset.num_classes, label_fraction=label_fraction, num_bins=num_bins, logger=logger
)
elif strategy == "stratified_regression":
indices, _ = stratify_regression_dataset_indices(
dataset, label_fraction=label_fraction, num_bins=num_bins, logger=logger
)
else: # Default to random sampling
n_samples = len(dataset)
indices = random.sample(
range(n_samples), int(n_samples * label_fraction)
)

return indices
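
To make the sampling path concrete, here is a small self-contained sketch of the stratified-classification strategy; the toy dataset and logger setup are illustrative only and not part of this PR:

```python
import logging
import torch
from pangaea.utils.subset_sampler import stratify_classification_dataset_indices

class ToySegDataset(torch.utils.data.Dataset):
    """Yields dicts with a 'target' mask, matching what GeoFMDataset returns."""
    num_classes = 2
    split = "train"

    def __init__(self, n=20):
        # Alternate mostly-background / mostly-foreground masks so that both
        # bin signatures stay represented after subsampling.
        self.masks = [torch.full((8, 8), i % 2) for i in range(n)]

    def __len__(self):
        return len(self.masks)

    def __getitem__(self, idx):
        return {"target": self.masks[idx]}

logging.basicConfig(level=logging.INFO)
selected, rest = stratify_classification_dataset_indices(
    ToySegDataset(), num_classes=2, label_fraction=0.5, num_bins=3,
    logger=logging.getLogger("toy"),
)
print(len(selected), len(rest))  # roughly half of the samples, drawn evenly from each bin
```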