diff --git a/.github/CONTRIBUTING.md b/.github/CONTRIBUTING.md
index 44e9d52e..20104923 100644
--- a/.github/CONTRIBUTING.md
+++ b/.github/CONTRIBUTING.md
@@ -177,16 +177,16 @@ We have designed the repo to allow for using your own datasets with minimal effo
 1. **Implement a Dataset Class**:
    - In the `pangaea/datasets/` directory, create a new Python file named after your dataset (e.g., `my_dataset.py`).
-   - Implement a class that inherits from `GeoFMDataset`. You can check it in `pangaea/datasets/base.py`.
+   - Implement a class that inherits from `RawGeoFMDataset`. You can check it in `pangaea/datasets/base.py`.
    - Be sure that your dataset is instantiated with all the required parameters from the `GeoFMDataset`. You can also add new parameters.
    - Implement the required methods: `__init__`, `__len__`, `__getitem__`, and `download` (if applicable, otherwise a `NotImplementedError is raised`).
    - **Example**:
     ```python
     import torch
-    from pangaea.datasets.base import GeoFMDataset
+    from pangaea.datasets.base import RawGeoFMDataset
 
-    class MyDataset(GeoFMDataset):
+    class MyDataset(RawGeoFMDataset):
        def __init__(
            self,
            split: str,
@@ -236,6 +236,25 @@ We have designed the repo to allow for using your own datasets with minimal effo
            return len(self.file_list)
 
        def __getitem__(self, index):
+           """Returns the i-th item of the dataset.
+
+           Args:
+               i (int): index of the item
+
+           Raises:
+               NotImplementedError: raise if the method is not implemented
+
+           Returns:
+               dict[str, torch.Tensor | dict[str, torch.Tensor]]: output dictionary following the format
+               {"image":
+                   {
+                       "optical": torch.Tensor of shape (C T H W) (where T=1 if single-temporal dataset),
+                       "sar": torch.Tensor of shape (C T H W) (where T=1 if single-temporal dataset),
+                   },
+               "target": torch.Tensor of shape (H W) of type torch.int64 for segmentation, torch.float for
+               regression datasets,
+               "metadata": dict}.
+           """
            # Load your data and labels here
            image = ...  # Load image
            target = ...
# Load target label or mask diff --git a/.gitignore b/.gitignore index 93c03efc..b0cffadb 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,8 @@ **/__pycache__/ *.egg-info +aaa_mystuff/ +work-dir/ outputs/ old_files/ pretrained/ @@ -8,6 +10,8 @@ pretrained_models/ work-dir/ wandb/ data/* +data +pretrained_models # Scenes for mados tiny # Wow gitignore pattern mathcing is a mess !data/mados diff --git a/configs/dataset/dynamicen.yaml b/configs/dataset/dynamicen.yaml index 7e443271..88ff361a 100644 --- a/configs/dataset/dynamicen.yaml +++ b/configs/dataset/dynamicen.yaml @@ -26,6 +26,14 @@ distribution: - 0.28 - 0.08 +sample_dates: + - 1 + - 5 + - 10 + - 15 + - 20 + - 25 + # data stats bands: optical: @@ -36,15 +44,15 @@ bands: data_mean: optical: - - 1042.59240722656 - - 915.618408203125 - 671.260559082031 + - 915.618408203125 + - 1042.59240722656 - 2605.20922851562 data_std: optical: - - 957.958435058593 - - 715.548767089843 - 596.943908691406 + - 715.548767089843 + - 957.958435058593 - 1059.90319824218 data_min: diff --git a/configs/dataset/pastis.yaml b/configs/dataset/pastis.yaml index 7bc2092e..9b292de5 100644 --- a/configs/dataset/pastis.yaml +++ b/configs/dataset/pastis.yaml @@ -10,8 +10,8 @@ multi_modal: True #limited_label: False # classes -ignore_index: 0 -num_classes: 19 +ignore_index: 19 +num_classes: 20 classes: - Background - Meadow @@ -32,7 +32,7 @@ classes: - Orchard - Mixed Cereal - Sorghum - #- Void Label + - Void Label distribution: - 0.00000 - 0.25675 @@ -53,7 +53,7 @@ distribution: - 0.02460 - 0.00696 - 0.00580 - #- 0.29476 + - 0.29476 bands: optical: diff --git a/configs/encoder/croma_joint.yaml b/configs/encoder/croma_joint.yaml index dfc68a67..cfecf55e 100644 --- a/configs/encoder/croma_joint.yaml +++ b/configs/encoder/croma_joint.yaml @@ -28,3 +28,5 @@ output_layers: - 5 - 7 - 11 + +output_dim: 1024 diff --git a/configs/encoder/croma_optical.yaml b/configs/encoder/croma_optical.yaml index 8615598d..7f2a7312 100644 --- a/configs/encoder/croma_optical.yaml +++ b/configs/encoder/croma_optical.yaml @@ -24,3 +24,5 @@ output_layers: - 5 - 7 - 11 + +output_dim: 1024 \ No newline at end of file diff --git a/configs/encoder/croma_sar.yaml b/configs/encoder/croma_sar.yaml index 42925f6b..10dbecd6 100644 --- a/configs/encoder/croma_sar.yaml +++ b/configs/encoder/croma_sar.yaml @@ -14,4 +14,6 @@ output_layers: - 3 - 5 - 7 - - 11 \ No newline at end of file + - 11 + +output_dim: 1024 \ No newline at end of file diff --git a/configs/encoder/dofa.yaml b/configs/encoder/dofa.yaml index 89a087f5..3daf76b0 100644 --- a/configs/encoder/dofa.yaml +++ b/configs/encoder/dofa.yaml @@ -41,3 +41,5 @@ output_layers: - 5 - 7 - 11 + +output_dim: 768 \ No newline at end of file diff --git a/configs/encoder/gfmswin.yaml b/configs/encoder/gfmswin.yaml index 5e4c6fd7..a3d564b5 100644 --- a/configs/encoder/gfmswin.yaml +++ b/configs/encoder/gfmswin.yaml @@ -7,7 +7,6 @@ in_chans: 3 t_patch_size: 3 depth: 12 embed_dim: 128 -output_dim: 1024 img_size: 192 # fixed to 192 to avoid interpolation in checkpoints which leads to drop in performance depths: [ 2, 2, 18, 2 ] num_heads: [ 4, 8, 16, 32 ] @@ -28,3 +27,8 @@ output_layers: - 2 - 3 +output_dim: + - 256 + - 512 + - 1024 + - 1024 diff --git a/configs/encoder/prithvi.yaml b/configs/encoder/prithvi.yaml index e7da1e2c..d82d6681 100644 --- a/configs/encoder/prithvi.yaml +++ b/configs/encoder/prithvi.yaml @@ -27,3 +27,5 @@ output_layers: - 5 - 7 - 11 + +output_dim: 768 \ No newline at end of file diff --git a/configs/encoder/remoteclip.yaml 
b/configs/encoder/remoteclip.yaml index 46eb3dfd..3288ed19 100644 --- a/configs/encoder/remoteclip.yaml +++ b/configs/encoder/remoteclip.yaml @@ -20,4 +20,6 @@ output_layers: - 3 - 5 - 7 - - 11 \ No newline at end of file + - 11 + +output_dim: 768 \ No newline at end of file diff --git a/configs/encoder/satlasnet_mi.yaml b/configs/encoder/satlasnet_mi.yaml index 2383bdf0..ca3ae89f 100644 --- a/configs/encoder/satlasnet_mi.yaml +++ b/configs/encoder/satlasnet_mi.yaml @@ -5,7 +5,6 @@ download_url: null model_identifier: Sentinel2_SwinB_MI_MS # Multi-Image Multi-Spectral fpn: False input_size: 128 -output_dim: 1024 input_bands: optical: @@ -23,3 +22,15 @@ input_bands: - B8 - B11 - B12 + +output_layers: + - 0 + - 1 + - 2 + - 3 + +output_dim: + - 128 + - 256 + - 512 + - 1024 \ No newline at end of file diff --git a/configs/encoder/satlasnet_si.yaml b/configs/encoder/satlasnet_si.yaml index 5759fcb4..d7958c02 100644 --- a/configs/encoder/satlasnet_si.yaml +++ b/configs/encoder/satlasnet_si.yaml @@ -5,7 +5,6 @@ download_url: null model_identifier: Sentinel2_SwinB_SI_MS # Single Image Multi-Spectral fpn: False input_size: 128 -output_dim: 1024 input_bands: optical: @@ -23,3 +22,15 @@ input_bands: - B8 - B11 - B12 + +output_layers: + - 0 + - 1 + - 2 + - 3 + +output_dim: + - 128 + - 256 + - 512 + - 1024 \ No newline at end of file diff --git a/configs/encoder/scalemae.yaml b/configs/encoder/scalemae.yaml index abe99595..4790d512 100644 --- a/configs/encoder/scalemae.yaml +++ b/configs/encoder/scalemae.yaml @@ -25,3 +25,5 @@ output_layers: - 11 - 15 - 23 + +output_dim: 1024 \ No newline at end of file diff --git a/configs/encoder/spectralgpt.yaml b/configs/encoder/spectralgpt.yaml index e880118a..04f4058c 100644 --- a/configs/encoder/spectralgpt.yaml +++ b/configs/encoder/spectralgpt.yaml @@ -3,7 +3,7 @@ encoder_weights: ./pretrained_models/SpectralGPT+.pth download_url: https://zenodo.org/records/8412455/files/SpectralGPT+.pth input_size: 128 -output_dim: 3072 # 768 * (in_chans / t_patch_size) + in_chans: 12 # number of spectral bands t_patch_size: 3 depth: 12 @@ -34,4 +34,5 @@ output_layers: - 7 - 11 +output_dim: 3072 # 768 * (in_chans / t_patch_size) diff --git a/configs/encoder/ssl4eo_data2vec.yaml b/configs/encoder/ssl4eo_data2vec.yaml index 0f6bcb3a..4c38b13b 100644 --- a/configs/encoder/ssl4eo_data2vec.yaml +++ b/configs/encoder/ssl4eo_data2vec.yaml @@ -31,3 +31,5 @@ output_layers: - 5 - 7 - 11 + +output_dim: 384 \ No newline at end of file diff --git a/configs/encoder/ssl4eo_dino.yaml b/configs/encoder/ssl4eo_dino.yaml index b69ff2a2..f006119b 100644 --- a/configs/encoder/ssl4eo_dino.yaml +++ b/configs/encoder/ssl4eo_dino.yaml @@ -31,3 +31,5 @@ output_layers: - 5 - 7 - 11 + +output_dim: 384 \ No newline at end of file diff --git a/configs/encoder/ssl4eo_mae_optical.yaml b/configs/encoder/ssl4eo_mae_optical.yaml index 51c43eda..2caac534 100644 --- a/configs/encoder/ssl4eo_mae_optical.yaml +++ b/configs/encoder/ssl4eo_mae_optical.yaml @@ -31,4 +31,6 @@ output_layers: - 3 - 5 - 7 - - 11 \ No newline at end of file + - 11 + +output_dim: 384 \ No newline at end of file diff --git a/configs/encoder/ssl4eo_mae_sar.yaml b/configs/encoder/ssl4eo_mae_sar.yaml index a3ac0113..a285940a 100644 --- a/configs/encoder/ssl4eo_mae_sar.yaml +++ b/configs/encoder/ssl4eo_mae_sar.yaml @@ -19,4 +19,6 @@ output_layers: - 3 - 5 - 7 - - 11 \ No newline at end of file + - 11 + +output_dim: 384 \ No newline at end of file diff --git a/configs/encoder/ssl4eo_moco.yaml b/configs/encoder/ssl4eo_moco.yaml index 
200e8a08..6f7ddbc2 100644 --- a/configs/encoder/ssl4eo_moco.yaml +++ b/configs/encoder/ssl4eo_moco.yaml @@ -31,3 +31,5 @@ output_layers: - 5 - 7 - 11 + +output_dim: 384 diff --git a/configs/encoder/unet_encoder.yaml b/configs/encoder/unet_encoder.yaml index b3bd2450..1db7e09a 100644 --- a/configs/encoder/unet_encoder.yaml +++ b/configs/encoder/unet_encoder.yaml @@ -6,3 +6,9 @@ topology: [64, 128, 256, 512,] input_bands: ${dataset.bands} +output_dim: + - 64 + - 128 + - 256 + - 512 + diff --git a/configs/preprocessing/reg_default.yaml b/configs/preprocessing/reg_default.yaml index bd199fca..1d468b89 100644 --- a/configs/preprocessing/reg_default.yaml +++ b/configs/preprocessing/reg_default.yaml @@ -1,35 +1,21 @@ -train: - - _target_: pangaea.engine.data_preprocessor.RegPreprocessor - # overwritten in run - dataset: null - encoder: null +train: + _target_: pangaea.engine.data_preprocessor.Preprocessor + preprocessor_cfg: + - _target_: pangaea.engine.data_preprocessor.RandomCropToEncoder + - _target_: pangaea.engine.data_preprocessor.BandFilter + - _target_: pangaea.engine.data_preprocessor.NormalizeMeanStd + - _target_: pangaea.engine.data_preprocessor.BandPadding - - _target_: pangaea.engine.data_preprocessor.NormalizeMinMax - # overwritten in run - dataset: null - encoder: null - data_min: -1 - data_max: 1 +val: + _target_: pangaea.engine.data_preprocessor.Preprocessor + preprocessor_cfg: + - _target_: pangaea.engine.data_preprocessor.BandFilter + - _target_: pangaea.engine.data_preprocessor.NormalizeMeanStd + - _target_: pangaea.engine.data_preprocessor.BandPadding - - _target_: pangaea.engine.data_preprocessor.RandomCropToEncoder - # overwritten in run - dataset: null - encoder: null - -test: - - _target_: pangaea.engine.data_preprocessor.RegPreprocessor - # overwritten in run - dataset: null - encoder: null - - - _target_: pangaea.engine.data_preprocessor.NormalizeMinMax - # overwritten in run - dataset: null - encoder: null - data_min: -1 - data_max: 1 - - - _target_: pangaea.engine.data_preprocessor.Tile - # overwritten in run - dataset: null - encoder: null +test: + _target_: pangaea.engine.data_preprocessor.Preprocessor + preprocessor_cfg: + - _target_: pangaea.engine.data_preprocessor.BandFilter + - _target_: pangaea.engine.data_preprocessor.NormalizeMeanStd + - _target_: pangaea.engine.data_preprocessor.BandPadding \ No newline at end of file diff --git a/configs/preprocessing/seg_default.yaml b/configs/preprocessing/seg_default.yaml index d6d40842..1d468b89 100644 --- a/configs/preprocessing/seg_default.yaml +++ b/configs/preprocessing/seg_default.yaml @@ -1,31 +1,21 @@ -train: - - _target_: pangaea.engine.data_preprocessor.SegPreprocessor - # overwritten in run - dataset: null - encoder: null +train: + _target_: pangaea.engine.data_preprocessor.Preprocessor + preprocessor_cfg: + - _target_: pangaea.engine.data_preprocessor.RandomCropToEncoder + - _target_: pangaea.engine.data_preprocessor.BandFilter + - _target_: pangaea.engine.data_preprocessor.NormalizeMeanStd + - _target_: pangaea.engine.data_preprocessor.BandPadding - - _target_: pangaea.engine.data_preprocessor.NormalizeMeanStd - # overwritten in run - dataset: null - encoder: null +val: + _target_: pangaea.engine.data_preprocessor.Preprocessor + preprocessor_cfg: + - _target_: pangaea.engine.data_preprocessor.BandFilter + - _target_: pangaea.engine.data_preprocessor.NormalizeMeanStd + - _target_: pangaea.engine.data_preprocessor.BandPadding - - _target_: pangaea.engine.data_preprocessor.RandomCropToEncoder - # overwritten 
in run - dataset: null - encoder: null - -test: - - _target_: pangaea.engine.data_preprocessor.SegPreprocessor - # overwritten in run - dataset: null - encoder: null - - - _target_: pangaea.engine.data_preprocessor.NormalizeMeanStd - # overwritten in run - dataset: null - encoder: null - - - _target_: pangaea.engine.data_preprocessor.Tile - # overwritten in run - dataset: null - encoder: null +test: + _target_: pangaea.engine.data_preprocessor.Preprocessor + preprocessor_cfg: + - _target_: pangaea.engine.data_preprocessor.BandFilter + - _target_: pangaea.engine.data_preprocessor.NormalizeMeanStd + - _target_: pangaea.engine.data_preprocessor.BandPadding \ No newline at end of file diff --git a/configs/preprocessing/seg_importance_crop.yaml b/configs/preprocessing/seg_importance_crop.yaml index f2147762..740e5ab7 100644 --- a/configs/preprocessing/seg_importance_crop.yaml +++ b/configs/preprocessing/seg_importance_crop.yaml @@ -1,31 +1,21 @@ -train: - - _target_: pangaea.engine.data_preprocessor.SegPreprocessor - # overwritten in run - dataset: null - encoder: null +train: + _target_: pangaea.engine.data_preprocessor.Preprocessor + preprocessor_cfg: + - _target_: pangaea.engine.data_preprocessor.ImportanceRandomCropToEncoder + - _target_: pangaea.engine.data_preprocessor.BandFilter + - _target_: pangaea.engine.data_preprocessor.NormalizeMeanStd + - _target_: pangaea.engine.data_preprocessor.BandPadding - - _target_: pangaea.engine.data_preprocessor.NormalizeMeanStd - # overwritten in run - dataset: null - encoder: null +val: + _target_: pangaea.engine.data_preprocessor.Preprocessor + preprocessor_cfg: + - _target_: pangaea.engine.data_preprocessor.BandFilter + - _target_: pangaea.engine.data_preprocessor.NormalizeMeanStd + - _target_: pangaea.engine.data_preprocessor.BandPadding - - _target_: pangaea.engine.data_preprocessor.ImportanceRandomCropToEncoder - # overwritten in run - dataset: null - encoder: null - -test: - - _target_: pangaea.engine.data_preprocessor.SegPreprocessor - # overwritten in run - dataset: null - encoder: null - - - _target_: pangaea.engine.data_preprocessor.NormalizeMeanStd - # overwritten in run - dataset: null - encoder: null - - - _target_: pangaea.engine.data_preprocessor.Tile - # overwritten in run - dataset: null - encoder: null +test: + _target_: pangaea.engine.data_preprocessor.Preprocessor + preprocessor_cfg: + - _target_: pangaea.engine.data_preprocessor.BandFilter + - _target_: pangaea.engine.data_preprocessor.NormalizeMeanStd + - _target_: pangaea.engine.data_preprocessor.BandPadding \ No newline at end of file diff --git a/configs/preprocessing/seg_resize.yaml b/configs/preprocessing/seg_resize.yaml index aaa37718..625a5e70 100644 --- a/configs/preprocessing/seg_resize.yaml +++ b/configs/preprocessing/seg_resize.yaml @@ -1,32 +1,23 @@ -train: - - _target_: pangaea.engine.data_preprocessor.SegPreprocessor - # overwritten in run - dataset: null - encoder: null +train: + _target_: pangaea.engine.data_preprocessor.Preprocessor + preprocessor_cfg: + - _target_: pangaea.engine.data_preprocessor.ResizeToEncoder + - _target_: pangaea.engine.data_preprocessor.BandFilter + - _target_: pangaea.engine.data_preprocessor.NormalizeMeanStd + - _target_: pangaea.engine.data_preprocessor.BandPadding - - _target_: pangaea.engine.data_preprocessor.NormalizeMeanStd - # overwritten in run - dataset: null - encoder: null - - - _target_: pangaea.engine.data_preprocessor.ResizeToEncoder - # overwritten in run - dataset: null - encoder: null - -test: - - _target_: 
pangaea.engine.data_preprocessor.SegPreprocessor - # overwritten in run - dataset: null - encoder: null - - - _target_: pangaea.engine.data_preprocessor.NormalizeMeanStd - # overwritten in run - dataset: null - encoder: null - - - _target_: pangaea.engine.data_preprocessor.ResizeToEncoder - # overwritten in run - dataset: null - encoder: null +val: + _target_: pangaea.engine.data_preprocessor.Preprocessor + preprocessor_cfg: + - _target_: pangaea.engine.data_preprocessor.ResizeToEncoder + - _target_: pangaea.engine.data_preprocessor.BandFilter + - _target_: pangaea.engine.data_preprocessor.NormalizeMeanStd + - _target_: pangaea.engine.data_preprocessor.BandPadding +test: + _target_: pangaea.engine.data_preprocessor.Preprocessor + preprocessor_cfg: + - _target_: pangaea.engine.data_preprocessor.ResizeToEncoder + - _target_: pangaea.engine.data_preprocessor.BandFilter + - _target_: pangaea.engine.data_preprocessor.NormalizeMeanStd + - _target_: pangaea.engine.data_preprocessor.BandPadding \ No newline at end of file diff --git a/configs/task/change_detection.yaml b/configs/task/change_detection.yaml index a4a64133..179e85ef 100644 --- a/configs/task/change_detection.yaml +++ b/configs/task/change_detection.yaml @@ -26,5 +26,7 @@ evaluator: exp_dir: null device: null use_wandb: ${use_wandb} + inference_mode: sliding + sliding_inference_batch: 8 diff --git a/configs/task/regression.yaml b/configs/task/regression.yaml index 623733b2..f596b59b 100644 --- a/configs/task/regression.yaml +++ b/configs/task/regression.yaml @@ -25,4 +25,6 @@ evaluator: val_loader: null exp_dir: null device: null - use_wandb: ${use_wandb} \ No newline at end of file + use_wandb: ${use_wandb} + inference_mode: sliding + sliding_inference_batch: 8 \ No newline at end of file diff --git a/configs/task/segmentation.yaml b/configs/task/segmentation.yaml index ff1a7404..b41224ad 100644 --- a/configs/task/segmentation.yaml +++ b/configs/task/segmentation.yaml @@ -26,5 +26,7 @@ evaluator: exp_dir: null device: null use_wandb: ${use_wandb} + inference_mode: sliding + sliding_inference_batch: 8 diff --git a/configs/train.yaml b/configs/train.yaml index 239db1f8..d447351c 100644 --- a/configs/train.yaml +++ b/configs/train.yaml @@ -5,9 +5,11 @@ use_wandb: false wandb_run_id: null # TRAINING -num_workers: 1 -batch_size: 32 +num_workers: 4 +batch_size: 8 +test_num_workers: 4 +test_batch_size: 1 # EXPERIMENT finetune: false diff --git a/pangaea/datasets/ai4smallfarms.py b/pangaea/datasets/ai4smallfarms.py index 8436de13..7e8c9b1b 100644 --- a/pangaea/datasets/ai4smallfarms.py +++ b/pangaea/datasets/ai4smallfarms.py @@ -8,10 +8,10 @@ from pyDataverse.api import DataAccessApi, NativeApi from tifffile import imread -from pangaea.datasets.base import GeoFMDataset +from pangaea.datasets.base import RawGeoFMDataset -class AI4SmallFarms(GeoFMDataset): +class AI4SmallFarms(RawGeoFMDataset): def __init__( self, split: str, @@ -121,6 +121,9 @@ def __getitem__(self, index): invalid_mask = torch.isnan(image) image[invalid_mask] = 0 + # output image shape (C T=1 H W) + image = image.unsqueeze(1) + # Convert target to a boolean tensor target = target.bool() diff --git a/pangaea/datasets/base.py b/pangaea/datasets/base.py index f611b5c1..4a3bbe15 100644 --- a/pangaea/datasets/base.py +++ b/pangaea/datasets/base.py @@ -1,8 +1,12 @@ +import os + import torch from torch.utils.data import Dataset, Subset -import os -class GeoFMDataset(Dataset): +from pangaea.engine.data_preprocessor import Preprocessor + + +class RawGeoFMDataset(Dataset): """Base 
class for all datasets.""" def __init__( @@ -99,10 +103,11 @@ def __getitem__(self, i: int) -> dict[str, torch.Tensor | dict[str, torch.Tensor dict[str, torch.Tensor | dict[str, torch.Tensor]]: output dictionary follwing the format {"image": { - "optical": torch.Tensor of shape (C H W) (or (C T H W) if multi-temporal dataset), - "sar": torch.Tensor of shape (C H W) (or (C T H W) if multi-temporal dataset) + "optical": torch.Tensor of shape (C T H W) (where T=1 if single-temporal dataset), + "sar": torch.Tensor of shape (C T H W) (where T=1 if single-temporal dataset), }, - "target": torch.Tensor of shape (H W), + "target": torch.Tensor of shape (H W) of type torch.int64 for segmentation, torch.float for + regression datasets., "metadata": dict}. """ raise NotImplementedError @@ -117,32 +122,70 @@ def download(self) -> None: raise NotImplementedError +class GeoFMDataset(Dataset): + """Base class for all datasets.""" + + def __init__( + self, + dataset: RawGeoFMDataset, + preprocessor: Preprocessor = None, + ): + """Initializes the dataset. + + Args: + + """ + super().__init__() + self.__dict__.update(dataset.__dict__) + self.raw_dataset = dataset + self.preprocessor = preprocessor + + def __len__(self) -> int: + """Returns the length of the dataset. + + Returns: + int: length of the dataset + """ + + return len(self.raw_dataset) + + def __getitem__(self, i: int) -> dict[str, torch.Tensor | dict[str, torch.Tensor]]: + """Returns the i-th item of the dataset. + + Args: + i (int): index of the item + + Raises: + NotImplementedError: raise if the method is not implemented + + Returns: + dict[str, torch.Tensor | dict[str, torch.Tensor]]: output dictionary follwing the format + {"image": + { + "optical": torch.Tensor of shape (C T H W) (where T=1 if single-temporal dataset), + "sar": torch.Tensor of shape (C T H W) (where T=1 if single-temporal dataset), + }, + "target": torch.Tensor of shape (H W) of type torch.int64 for segmentation, torch.float for + regression datasets., + "metadata": dict}. 
+ """ + + output = self.raw_dataset[i] + if self.preprocessor is not None: + output = self.preprocessor(output) + + return output + + class GeoFMSubset(Subset): """Custom subset class that retains dataset attributes.""" def __init__(self, dataset, indices): super().__init__(dataset, indices) - + # Copy relevant attributes from the original dataset - self.dataset_name = getattr(dataset, 'dataset_name', None) - self.root_path = getattr(dataset, 'root_path', None) - self.auto_download = getattr(dataset, 'auto_download', None) - self.download_url = getattr(dataset, 'download_url', None) - self.img_size = getattr(dataset, 'img_size', None) - self.multi_temporal = getattr(dataset, 'multi_temporal', None) - self.multi_modal = getattr(dataset, 'multi_modal', None) - self.ignore_index = getattr(dataset, 'ignore_index', None) - self.num_classes = getattr(dataset, 'num_classes', None) - self.classes = getattr(dataset, 'classes', None) - self.distribution = getattr(dataset, 'distribution', None) - self.bands = getattr(dataset, 'bands', None) - self.data_mean = getattr(dataset, 'data_mean', None) - self.data_std = getattr(dataset, 'data_std', None) - self.data_min = getattr(dataset, 'data_min', None) - self.data_max = getattr(dataset, 'data_max', None) - self.split = getattr(dataset, 'split', None) + self.__dict__.update(dataset.__dict__) def filter_by_indices(self, indices): """Apply filtering by indices directly in this subset.""" return GeoFMSubset(self.dataset, indices) - diff --git a/pangaea/datasets/biomassters.py b/pangaea/datasets/biomassters.py index 8e98b785..b895f06a 100644 --- a/pangaea/datasets/biomassters.py +++ b/pangaea/datasets/biomassters.py @@ -7,7 +7,7 @@ from os.path import join as opj from pangaea.datasets.utils import read_tif -from pangaea.datasets.base import GeoFMDataset +from pangaea.datasets.base import RawGeoFMDataset def read_imgs(multi_temporal, temp , fname, data_dir, img_size): imgs_s1, imgs_s2, mask = [], [], [] @@ -49,7 +49,7 @@ def read_imgs(multi_temporal, temp , fname, data_dir, img_size): imgs_s2 = np.stack(imgs_s2, axis=1) return imgs_s1, imgs_s2, mask -class BioMassters(GeoFMDataset): +class BioMassters(RawGeoFMDataset): def __init__( self, split: str, diff --git a/pangaea/datasets/croptypemapping.py b/pangaea/datasets/croptypemapping.py index e59e4eec..b8bbe8d1 100644 --- a/pangaea/datasets/croptypemapping.py +++ b/pangaea/datasets/croptypemapping.py @@ -12,12 +12,12 @@ import torch -from pangaea.datasets.base import GeoFMDataset +from pangaea.datasets.base import RawGeoFMDataset # from utils.registry import DATASET_REGISTRY # @DATASET_REGISTRY.register() -class CropTypeMappingSouthSudan(GeoFMDataset): +class CropTypeMappingSouthSudan(RawGeoFMDataset): def __init__( self, split: str, @@ -145,7 +145,7 @@ def __getitem__(self, idx): label = np.load(os.path.join(self.root_path, self.country, 'truth', f'{self.country}_{loc_id}.npz'))['truth'] label = self._mapping_label(label) - label = torch.from_numpy(label).float() + label = torch.from_numpy(label).long() metadata = self.get_metadata(idx) diff --git a/pangaea/datasets/fivebillionpixels.py b/pangaea/datasets/fivebillionpixels.py index 4cae19ef..f07f07a6 100644 --- a/pangaea/datasets/fivebillionpixels.py +++ b/pangaea/datasets/fivebillionpixels.py @@ -1,27 +1,16 @@ import os -import time -import torch -import numpy as np -import rasterio -import random from glob import glob -from PIL import Image +import numpy as np import tifffile as tiff -import cv2 - import torch import torchvision.transforms.functional as TF 
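
For orientation, here is a minimal sketch of how the `RawGeoFMDataset` / `GeoFMDataset` split introduced in `pangaea/datasets/base.py` above is meant to compose. `ToyRawDataset` is a hypothetical stand-in used only to show the expected sample format; in the benchmark, concrete datasets and the `Preprocessor` are instantiated from the Hydra configs and wrapped in the run script.

```python
import torch
from torch.utils.data import Dataset

from pangaea.datasets.base import GeoFMDataset, RawGeoFMDataset


class ToyRawDataset(RawGeoFMDataset):
    """Hypothetical minimal raw dataset, only to illustrate the expected sample format."""

    def __init__(self):
        # NOTE: the real RawGeoFMDataset.__init__ takes the full set of dataset
        # parameters (split, bands, data_mean, ...); they are skipped here on purpose.
        Dataset.__init__(self)
        self.split = "train"

    def __len__(self):
        return 4

    def __getitem__(self, i):
        return {
            "image": {"optical": torch.rand(4, 1, 32, 32)},    # (C, T=1, H, W)
            "target": torch.zeros(32, 32, dtype=torch.int64),  # segmentation mask
            "metadata": {},
        }


# Without a preprocessor the wrapper simply forwards the raw samples; in the benchmark
# a Preprocessor built from configs/preprocessing/*.yaml is passed instead.
dataset = GeoFMDataset(ToyRawDataset(), preprocessor=None)
print(dataset[0]["image"]["optical"].shape)  # torch.Size([4, 1, 32, 32])
```
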
-import torchvision.transforms as T +from PIL import Image -import pathlib -import urllib -import tarfile -from pangaea.datasets.utils import DownloadProgressBar +from pangaea.datasets.base import RawGeoFMDataset -from pangaea.datasets.base import GeoFMDataset -class FiveBillionPixels(GeoFMDataset): +class FiveBillionPixels(RawGeoFMDataset): def __init__( self, split: str, @@ -55,10 +44,10 @@ def __init__( classes (list): classes of the dataset. num_classes (int): number of classes. ignore_index (int): index to ignore for metrics and loss. - img_size (int): size of the image. + img_size (int): size of the image. bands (dict[str, list[str]]): bands of the dataset. distribution (list[int]): class distribution. - data_mean (dict[str, list[str]]): mean for each band for each modality. + data_mean (dict[str, list[str]]): mean for each band for each modality. Dictionary with keys as the modality and values as the list of means. e.g. {"s2": [b1_mean, ..., bn_mean], "s1": [b1_mean, ..., bn_mean]} data_std (dict[str, list[str]]): str for each band for each modality. @@ -112,41 +101,46 @@ def __init__( self.download_url = download_url self.auto_download = auto_download - self._image_dir = sorted(glob(os.path.join(self._base_dir, self.split, 'imgs', '*.tif'))) - self._label_dir = sorted(glob(os.path.join(self._base_dir, self.split, 'labels', '*.tif'))) + self._image_dir = sorted( + glob(os.path.join(self._base_dir, self.split, "imgs", "*.tif")) + ) + self._label_dir = sorted( + glob(os.path.join(self._base_dir, self.split, "labels", "*.tif")) + ) def __len__(self): return len(self._image_dir) def __getitem__(self, index): - if self.use_cmyk: - image = Image.open(self._image_dir[index]).convert('CMYK') + image = Image.open(self._image_dir[index]).convert("CMYK") image = TF.pil_to_tensor(image) else: - image = tiff.imread(self._image_dir[index])#.convert('CMYK') #check it also on the normalization + image = tiff.imread( + self._image_dir[index] + ) # .convert('CMYK') #check it also on the normalization image = image.astype(np.float32) # Convert to float32 image = torch.from_numpy(image).permute(2, 0, 1) - + + # output image shape (C T=1 H W) + image = image.unsqueeze(1) target = tiff.imread(self._label_dir[index]) target = target.astype(np.int64) # Convert to int64 (since it's a mask) target = torch.from_numpy(target).long() output = { - 'image': { - 'optical': image, + "image": { + "optical": image, }, - 'target': target, - 'metadata': {} + "target": target, + "metadata": {}, } - + return output - # @staticmethod # def get_splits(dataset_config): # dataset_train = FiveBillionPixels(dataset_config, split="train") # dataset_val = FiveBillionPixels(dataset_config, split="val") # dataset_test = FiveBillionPixels(dataset_config, split="test") # return dataset_train, dataset_val, dataset_test - diff --git a/pangaea/datasets/hlsburnscars.py b/pangaea/datasets/hlsburnscars.py index 0678660e..9ff87c2b 100644 --- a/pangaea/datasets/hlsburnscars.py +++ b/pangaea/datasets/hlsburnscars.py @@ -1,23 +1,21 @@ import os +import pathlib +import tarfile import time -import torch -import numpy as np -import tifffile as tiff -from typing import Sequence, Tuple -from sklearn.model_selection import train_test_split +import urllib from glob import glob +from typing import Sequence, Tuple +import numpy as np +import tifffile as tiff import torch +from sklearn.model_selection import train_test_split -import pathlib -import urllib -import tarfile - +from pangaea.datasets.base import RawGeoFMDataset from pangaea.datasets.utils 
import DownloadProgressBar -from pangaea.datasets.base import GeoFMDataset -class HLSBurnScars(GeoFMDataset): +class HLSBurnScars(RawGeoFMDataset): def __init__( self, split: str, @@ -38,7 +36,6 @@ def __init__( download_url: str, auto_download: bool, ): - """Initialize the HLSBurnScars dataset. Link: https://huggingface.co/datasets/ibm-nasa-geospatial/hls_burn_scars @@ -51,10 +48,10 @@ def __init__( classes (list): classes of the dataset. num_classes (int): number of classes. ignore_index (int): index to ignore for metrics and loss. - img_size (int): size of the image. + img_size (int): size of the image. bands (dict[str, list[str]]): bands of the dataset. distribution (list[int]): class distribution. - data_mean (dict[str, list[str]]): mean for each band for each modality. + data_mean (dict[str, list[str]]): mean for each band for each modality. Dictionary with keys as the modality and values as the list of means. e.g. {"s2": [b1_mean, ..., bn_mean], "s1": [b1_mean, ..., bn_mean]} data_std (dict[str, list[str]]): str for each band for each modality. @@ -69,7 +66,7 @@ def __init__( download_url (str): url to download the dataset. auto_download (bool): whether to download the dataset automatically. """ - + super(HLSBurnScars, self).__init__( split=split, dataset_name=dataset_name, @@ -93,7 +90,7 @@ def __init__( self.root_path = root_path self.classes = classes self.split = split - + self.data_mean = data_mean self.data_std = data_std self.data_min = data_min @@ -106,10 +103,26 @@ def __init__( self.download_url = download_url self.auto_download = auto_download - self.split_mapping = {'train': 'training', 'val': 'validation', 'test': 'validation'} + self.split_mapping = { + "train": "training", + "val": "validation", + "test": "validation", + } - all_files = sorted(glob(os.path.join(self.root_path, self.split_mapping[self.split], '*merged.tif'))) - all_targets = sorted(glob(os.path.join(self.root_path, self.split_mapping[self.split], '*mask.tif'))) + all_files = sorted( + glob( + os.path.join( + self.root_path, self.split_mapping[self.split], "*merged.tif" + ) + ) + ) + all_targets = sorted( + glob( + os.path.join( + self.root_path, self.split_mapping[self.split], "*mask.tif" + ) + ) + ) if self.split != "test": split_indices = self.get_train_val_split(all_files) @@ -123,18 +136,17 @@ def __init__( self.image_list = all_files self.target_list = all_targets - @staticmethod def get_train_val_split(all_files) -> Tuple[Sequence[int], Sequence[int]]: - - # Fixed stratified sample to split data into train/val. - # This keeps 90% of datapoints belonging to an individual event in the training set and puts the remaining 10% in the validation set. - train_idxs, val_idxs = train_test_split(np.arange(len(all_files)), - test_size=0.1, - random_state=23, - ) + # Fixed stratified sample to split data into train/val. + # This keeps 90% of datapoints belonging to an individual event in the training set and puts the remaining 10% in the validation set. 
+ train_idxs, val_idxs = train_test_split( + np.arange(len(all_files)), + test_size=0.1, + random_state=23, + ) return {"train": train_idxs, "val": val_idxs} - + def __len__(self): return len(self.image_list) @@ -150,17 +162,18 @@ def __getitem__(self, index): invalid_mask = image == 9999 image[invalid_mask] = 0 + # images must have (C T H W) shape + image = image.unsqueeze(1) output = { - 'image': { - 'optical': image, + "image": { + "optical": image, }, - 'target': target, - 'metadata': {} + "target": target, + "metadata": {}, } return output - @staticmethod def download(self, silent=False): output_path = pathlib.Path(self.root_path) @@ -170,7 +183,9 @@ def download(self, silent=False): os.makedirs(output_path, exist_ok=False) except FileExistsError: if not silent: - print("HLSBurnScars dataset folder exists, skipping downloading dataset.") + print( + "HLSBurnScars dataset folder exists, skipping downloading dataset." + ) return temp_file_name = f"temp_{hex(int(time.time()))}_hls_burn_scars.tar.gz" @@ -179,17 +194,20 @@ def download(self, silent=False): try: urllib.request.urlretrieve(url, output_path / temp_file_name, pbar) except urllib.error.HTTPError as e: - print('Error while downloading dataset: The server couldn\'t fulfill the request.') - print('Error code: ', e.code) + print( + "Error while downloading dataset: The server couldn't fulfill the request." + ) + print("Error code: ", e.code) return except urllib.error.URLError as e: - print('Error while downloading dataset: Failed to reach a server.') - print('Reason: ', e.reason) + print("Error while downloading dataset: Failed to reach a server.") + print("Reason: ", e.reason) return - with tarfile.open(output_path / temp_file_name, 'r') as tar: + with tarfile.open(output_path / temp_file_name, "r") as tar: print(f"Extracting to {output_path} ...") tar.extractall(output_path) print("done.") - os.remove(output_path / temp_file_name) \ No newline at end of file + os.remove(output_path / temp_file_name) + diff --git a/pangaea/datasets/mados.py b/pangaea/datasets/mados.py index bcc6f620..42ac6f11 100644 --- a/pangaea/datasets/mados.py +++ b/pangaea/datasets/mados.py @@ -1,6 +1,5 @@ import os -import time -import pathlib +import time import pathlib import urllib.request import urllib.error import zipfile @@ -18,14 +17,14 @@ import torchvision.transforms as T from pangaea.datasets.utils import DownloadProgressBar -from pangaea.datasets.base import GeoFMDataset +from pangaea.datasets.base import RawGeoFMDataset ############################################################### # MADOS DATASET # ############################################################### # @DATASET_REGISTRY.register() -class MADOS(GeoFMDataset): +class MADOS(RawGeoFMDataset): def __init__( self, split: str, @@ -160,6 +159,8 @@ def __getitem__(self, index): invalid_mask = torch.isnan(image) image[invalid_mask] = 0 + # images must be of shape (C T H W) + image = image.unsqueeze(1) with rasterio.open(self.target_list[index], mode='r') as src: target = src.read(1) @@ -219,4 +220,4 @@ def download(self, silent=False): zip_ref.extractall(output_path, members) print("done.") - (output_path / temp_file_name).unlink() \ No newline at end of file + (output_path / temp_file_name).unlink() diff --git a/pangaea/datasets/pastis.py b/pangaea/datasets/pastis.py index 9abc3ac2..cd4471ae 100644 --- a/pangaea/datasets/pastis.py +++ b/pangaea/datasets/pastis.py @@ -14,7 +14,7 @@ import torch from einops import rearrange -from pangaea.datasets.base import GeoFMDataset +from 
pangaea.datasets.base import RawGeoFMDataset def prepare_dates(date_dict, reference_date): @@ -64,7 +64,7 @@ def split_image(image_tensor, nb_split, id): ].float() -class Pastis(GeoFMDataset): +class Pastis(RawGeoFMDataset): def __init__( self, split: str, @@ -141,31 +141,15 @@ def __init__( folds = [4] else: folds = [5] - - self.dataset_name = dataset_name - self.bands = bands - self.split = split - self.path = root_path - self.data_mean = data_mean - self.data_std = data_std - self.data_min = data_min - self.data_max = data_max - self.classes = classes - self.img_size = img_size - self.distribution = distribution - - self.num_classes = num_classes - self.ignore_index = ignore_index - self.grid_size = multi_temporal - self.download_url = download_url - self.auto_download = auto_download self.modalities = ["s2", "aerial", "s1-asc"] self.nb_split = 1 reference_date = "2018-09-01" self.reference_date = datetime(*map(int, reference_date.split("-"))) - self.meta_patch = gpd.read_file(os.path.join(self.path, "metadata.geojson")) + self.meta_patch = gpd.read_file( + os.path.join(self.root_path, "metadata.geojson") + ) self.num_classes = 20 @@ -193,19 +177,16 @@ def __getitem__(self, i: int) -> dict[str, torch.Tensor | dict[str, torch.Tensor part = i % (self.nb_split * self.nb_split) label = torch.from_numpy( np.load( - os.path.join(self.path, "ANNOTATIONS/TARGET_" + str(name) + ".npy") + os.path.join(self.root_path, "ANNOTATIONS/TARGET_" + str(name) + ".npy") )[0].astype(np.int32) ) - # remove void class - label[label == 19] = self.ignore_index - # label = label[1:-1] # remove Background and Void classes output = {"label": label, "name": name} for modality in self.modalities: if modality == "aerial": with rasterio.open( os.path.join( - self.path, + self.root_path, "DATA_SPOT/PASTIS_SPOT6_RVB_1M00_2019/SPOT6_RVB_1M00_2019_" + str(name) + ".tif", @@ -220,7 +201,7 @@ def __getitem__(self, i: int) -> dict[str, torch.Tensor | dict[str, torch.Tensor torch.from_numpy( np.load( os.path.join( - self.path, + self.root_path, "DATA_{}".format(modality_name.upper()), "{}_{}.npy".format(modality_name.upper(), name), ) @@ -237,7 +218,7 @@ def __getitem__(self, i: int) -> dict[str, torch.Tensor | dict[str, torch.Tensor torch.from_numpy( np.load( os.path.join( - self.path, + self.root_path, "DATA_{}".format(modality_name.upper()), "{}_{}.npy".format(modality_name.upper(), name), ) @@ -254,7 +235,7 @@ def __getitem__(self, i: int) -> dict[str, torch.Tensor | dict[str, torch.Tensor torch.from_numpy( np.load( os.path.join( - self.path, + self.root_path, "DATA_{}".format(modality_name.upper()), "{}_{}.npy".format(modality_name.upper(), name), ) @@ -286,7 +267,7 @@ def __getitem__(self, i: int) -> dict[str, torch.Tensor | dict[str, torch.Tensor torch.from_numpy( np.load( os.path.join( - self.path, + self.root_path, "DATA_{}".format(modality_name.upper()), "{}_{}.npy".format(modality_name.upper(), name), ) @@ -319,7 +300,7 @@ def __getitem__(self, i: int) -> dict[str, torch.Tensor | dict[str, torch.Tensor torch.from_numpy( np.load( os.path.join( - self.path, + self.root_path, "DATA_{}".format(modality_name.upper()), "{}_{}.npy".format(modality_name.upper(), name), ) @@ -337,7 +318,7 @@ def __getitem__(self, i: int) -> dict[str, torch.Tensor | dict[str, torch.Tensor torch.from_numpy( np.load( os.path.join( - self.path, + self.root_path, "DATA_{}".format(modality.upper()), "{}_{}.npy".format(modality.upper(), name), ) @@ -360,17 +341,17 @@ def __getitem__(self, i: int) -> dict[str, torch.Tensor | dict[str, torch.Tensor 
optical_ts = rearrange(output["s2"], "t c h w -> c t h w") sar_ts = rearrange(output["s1-asc"], "t c h w -> c t h w") - if self.grid_size == 1: + if self.multi_temporal == 1: # we only take the last frame optical_ts = optical_ts[:, -1] sar_ts = sar_ts[:, -1] else: # select evenly spaced samples optical_indexes = torch.linspace( - 0, optical_ts.shape[1] - 1, self.grid_size, dtype=torch.long + 0, optical_ts.shape[1] - 1, self.multi_temporal, dtype=torch.long ) sar_indexes = torch.linspace( - 0, sar_ts.shape[1] - 1, self.grid_size, dtype=torch.long + 0, sar_ts.shape[1] - 1, self.multi_temporal, dtype=torch.long ) optical_ts = optical_ts[:, optical_indexes] @@ -381,7 +362,7 @@ def __getitem__(self, i: int) -> dict[str, torch.Tensor | dict[str, torch.Tensor "optical": optical_ts.to(torch.float32), "sar": sar_ts.to(torch.float32), }, - "target": output["label"], + "target": output["label"].to(torch.int64), "metadata": {}, } diff --git a/pangaea/datasets/sen1floods11.py b/pangaea/datasets/sen1floods11.py index 76680c7d..ab1b0db2 100644 --- a/pangaea/datasets/sen1floods11.py +++ b/pangaea/datasets/sen1floods11.py @@ -1,17 +1,18 @@ # Source: https://github.com/cloudtostreet/Sen1Floods11 import os + import geopandas import numpy as np import pandas as pd -import rasterio +import rasterio import torch +from pangaea.datasets.base import RawGeoFMDataset from pangaea.datasets.utils import download_bucket_concurrently -from pangaea.datasets.base import GeoFMDataset -class Sen1Floods11(GeoFMDataset): +class Sen1Floods11(RawGeoFMDataset): def __init__( self, split: str, @@ -31,7 +32,7 @@ def __init__( data_max: dict[str, list[str]], download_url: str, auto_download: bool, - gcs_bucket: str, + gcs_bucket: str, ): """Initialize the Sen1Floods11 dataset. Link: https://github.com/cloudtostreet/Sen1Floods11 @@ -45,10 +46,10 @@ def __init__( classes (list): classes of the dataset. num_classes (int): number of classes. ignore_index (int): index to ignore for metrics and loss. - img_size (int): size of the image. + img_size (int): size of the image. bands (dict[str, list[str]]): bands of the dataset. distribution (list[int]): class distribution. - data_mean (dict[str, list[str]]): mean for each band for each modality. + data_mean (dict[str, list[str]]): mean for each band for each modality. Dictionary with keys as the modality and values as the list of means. e.g. {"s2": [b1_mean, ..., bn_mean], "s1": [b1_mean, ..., bn_mean]} data_std (dict[str, list[str]]): str for each band for each modality. 
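
For reference, a small sketch of the evenly-spaced temporal sub-sampling that the PASTIS changes above rely on (`self.multi_temporal` replacing the old `self.grid_size`). The series length and tensor sizes below are illustrative values, not taken from the dataset.

```python
import torch

# A fake multi-temporal optical series of shape (C, T, H, W).
optical_ts = torch.randn(10, 43, 128, 128)
multi_temporal = 6  # number of frames kept by the dataset config

# Select evenly spaced frame indices over the full series, as done in pastis.py above.
indexes = torch.linspace(0, optical_ts.shape[1] - 1, multi_temporal, dtype=torch.long)
optical_ts = optical_ts[:, indexes]
print(optical_ts.shape)  # torch.Size([10, 6, 128, 128])
```
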
@@ -102,12 +103,20 @@ def __init__( self.ignore_index = ignore_index self.download_url = download_url self.auto_download = auto_download - - self.split_mapping = {'train': 'train', 'val': 'valid', 'test': 'test'} - split_file = os.path.join(self.root_path, "v1.1", f"splits/flood_handlabeled/flood_{self.split_mapping[split]}_data.csv") - metadata_file = os.path.join(self.root_path, "v1.1", "Sen1Floods11_Metadata.geojson") - data_root = os.path.join(self.root_path, "v1.1", "data/flood_events/HandLabeled/") + self.split_mapping = {"train": "train", "val": "valid", "test": "test"} + + split_file = os.path.join( + self.root_path, + "v1.1", + f"splits/flood_handlabeled/flood_{self.split_mapping[split]}_data.csv", + ) + metadata_file = os.path.join( + self.root_path, "v1.1", "Sen1Floods11_Metadata.geojson" + ) + data_root = os.path.join( + self.root_path, "v1.1", "data/flood_events/HandLabeled/" + ) self.metadata = geopandas.read_file(metadata_file) @@ -116,10 +125,16 @@ def __init__( file_list = [f.rstrip().split(",") for f in file_list] - self.s1_image_list = [os.path.join(data_root, 'S1Hand', f[0]) for f in file_list] - self.s2_image_list = [os.path.join(data_root, 'S2Hand', f[0].replace('S1Hand', 'S2Hand')) for f in file_list] - self.target_list = [os.path.join(data_root, 'LabelHand', f[1]) for f in file_list] - + self.s1_image_list = [ + os.path.join(data_root, "S1Hand", f[0]) for f in file_list + ] + self.s2_image_list = [ + os.path.join(data_root, "S2Hand", f[0].replace("S1Hand", "S2Hand")) + for f in file_list + ] + self.target_list = [ + os.path.join(data_root, "LabelHand", f[1]) for f in file_list + ] def __len__(self): return len(self.s1_image_list) @@ -130,7 +145,9 @@ def _get_date(self, index): if self.metadata[self.metadata["location"] == location].shape[0] != 1: date = pd.to_datetime("13-10-1998", dayfirst=True) else: - date = pd.to_datetime(self.metadata[self.metadata["location"] == location]["s2_date"].item()) + date = pd.to_datetime( + self.metadata[self.metadata["location"] == location]["s2_date"].item() + ) date_np = np.zeros((1, 3)) date_np[0, 0] = date.year date_np[0, 1] = date.dayofyear - 1 # base 0 @@ -148,31 +165,32 @@ def __getitem__(self, index): with rasterio.open(self.target_list[index]) as src: target = src.read(1) - + timestamp = self._get_date(index) s2_image = torch.from_numpy(s2_image).float() - s1_image = torch.from_numpy(s1_image).float() - target = torch.from_numpy(target) + s1_image = torch.from_numpy(s1_image).float() + target = torch.from_numpy(target).long() output = { - 'image': { - 'optical': s2_image, - 'sar' : s1_image, + "image": { + "optical": s2_image.unsqueeze(1), + "sar": s1_image.unsqueeze(1), }, - 'target': target, - 'metadata': { + "target": target, + "metadata": { "timestamp": timestamp, - } + }, } + return output @staticmethod def download(self, silent=False): if os.path.exists(self.root_path): if not silent: - print("Sen1Floods11 Dataset folder exists, skipping downloading dataset.") + print( + "Sen1Floods11 Dataset folder exists, skipping downloading dataset." 
+ ) return download_bucket_concurrently(self.gcs_bucket, self.root_path) - - diff --git a/pangaea/datasets/spacenet7.py b/pangaea/datasets/spacenet7.py index 144eb4b3..557ddd7f 100644 --- a/pangaea/datasets/spacenet7.py +++ b/pangaea/datasets/spacenet7.py @@ -23,7 +23,7 @@ from abc import abstractmethod from pangaea.datasets.utils import DownloadProgressBar -from pangaea.datasets.base import GeoFMDataset +from pangaea.datasets.base import RawGeoFMDataset # from utils.registry import DATASET_REGISTRY # train/val/test split from https://doi.org/10.3390/rs15215135 @@ -98,7 +98,7 @@ # SPACENET 7 DATASET # ############################################################### -class AbstractSN7(GeoFMDataset): +class AbstractSN7(RawGeoFMDataset): def __init__( self, @@ -562,7 +562,7 @@ def __getitem__(self, index): year_t1, month_t1 = timestamps[0]['year'], timestamps[0]['month'] year_t2, month_t2 = timestamps[-1]['year'], timestamps[-1]['month'] target = self.load_change_label(aoi_id, year_t1, month_t1, year_t2, month_t2) - target = torch.from_numpy(target) + target = torch.from_numpy(target).long() # cut to tile i, j = item['i'], item['j'] diff --git a/pangaea/datasets/utae_dynamicen.py b/pangaea/datasets/utae_dynamicen.py index 7705e331..8a5ef762 100644 --- a/pangaea/datasets/utae_dynamicen.py +++ b/pangaea/datasets/utae_dynamicen.py @@ -11,12 +11,12 @@ # import random # from PIL import Image -from pangaea.datasets.base import GeoFMDataset +from pangaea.datasets.base import RawGeoFMDataset # from utils.registry import DATASET_REGISTRY # @DATASET_REGISTRY.register() -class DynamicEarthNet(GeoFMDataset): +class DynamicEarthNet(RawGeoFMDataset): def __init__( self, split: str, @@ -36,6 +36,7 @@ def __init__( data_max: dict[str, list[str]], download_url: str, auto_download: bool, + sample_dates: list[int] ): """Initialize the DynamicEarthNet dataset. 
Link: https://github.com/aysim/dynnet @@ -101,7 +102,7 @@ def __init__( self.download_url = download_url self.auto_download = auto_download - self.mode = 'weekly' + self.sample_dates = [str(d).rjust(2,'0') for d in sample_dates] self.files = [] @@ -110,6 +111,7 @@ def __init__( self.set_files() + def set_files(self): self.file_list = os.path.join(self.root_path, "dynnet_training_splits", f"{self.split}" + ".txt") @@ -118,97 +120,22 @@ def set_files(self): self.files, self.labels, self.year_months = list(zip(*file_list)) self.files = [f.replace('/reprocess-cropped/UTM-24000/', '/planet/') for f in self.files] - if self.mode == 'daily': - self.all_days = list(range(len(self.files))) - - for i in range(len(self.files)): - self.planet, self.day = [], [] - date_count = 0 - for _, _, infiles in os.walk(os.path.join(self.root_path, self.files[i][1:])): - for infile in sorted(infiles): - if infile.startswith(self.year_months[i]): - self.planet.append(os.path.join(self.files[i], infile)) - self.day.append((datetime(int(str(infile.split('.')[0])[:4]), int(str(infile.split('.')[0][5:7])), - int(str(infile.split('.')[0])[8:])) - self.reference_date).days) - date_count += 1 - self.all_days[i] = list(zip(self.planet, self.day)) - self.all_days[i].insert(0, date_count) - - else: - self.planet, self.day = [], [] - if self.mode == 'weekly': - self.dates = ['01', '05', '10', '15', '20', '25'] - elif self.mode == 'single': - self.dates = ['01'] - - for i, year_month in enumerate(self.year_months): - for date in self.dates: - curr_date = year_month + '-' + date - self.planet.append(os.path.join(self.files[i], curr_date + '.tif')) - self.day.append((datetime(int(str(curr_date)[:4]), int(str(curr_date[5:7])), - int(str(curr_date)[8:])) - self.reference_date).days) - self.planet_day = list(zip(*[iter(self.planet)] * len(self.dates), *[iter(self.day)] * len(self.dates))) - - - def load_data(self, index): - cur_images, cur_dates = [], [] - if self.mode == 'daily': - for i in range(1, self.all_days[index][0]+1): - img = rasterio.open(os.path.join(self.root_path, self.all_days[index][i][0][1:])) - red = img.read(3) - green = img.read(2) - blue = img.read(1) - nir = img.read(4) - image = np.dstack((red, green, blue, nir)) - cur_images.append(np.expand_dims(np.asarray(image, dtype=np.float32), axis=0)) # np.array already\ - cur_dates.append(self.all_days[index][i][1]) - - image_stack = np.concatenate(cur_images, axis=0) - dates = torch.from_numpy(np.array(cur_dates, dtype=np.int32)) - label = rasterio.open(os.path.join(self.root_path, self.labels[index][1:])) - label = label.read() - mask = np.zeros((label.shape[1], label.shape[2]), dtype=np.int32) - - for i in range(self.num_classes + 1): - if i == 6: - mask[label[i, :, :] == 255] = -1 - else: - mask[label[i, :, :] == 255] = i - - return (image_stack, dates), mask - - else: - for i in range(len(self.dates)): - # read .tif - img = rasterio.open(os.path.join(self.root_path, self.planet_day[index][i][1:])) - red = img.read(3) - green = img.read(2) - blue = img.read(1) - nir = img.read(4) - image = np.dstack((red, green, blue, nir)) - cur_images.append(np.expand_dims(np.asarray(image, dtype=np.float32), axis=0)) # np.array already\ - image_stack = np.concatenate(cur_images, axis=0) - dates = torch.from_numpy(np.array(self.planet_day[index][len(self.dates):], dtype=np.int32)) - label = rasterio.open(os.path.join(self.root_path, self.labels[index][1:])) - label = label.read() - mask = np.zeros((label.shape[1], label.shape[2]), dtype=np.int32) - - for i in 
range(self.num_classes + 1): - if i == 6: - mask[label[i, :, :] == 255] = -1 - else: - mask[label[i, :, :] == 255] = i - - return (image_stack, dates), mask + self.all_sequences = [] + for f, ym in zip(self.files, self.year_months): + images = [] + for date in self.sample_dates: + image_file = os.path.join(self.root_path, f[1:], f"{ym}-{date}.npy") + assert os.path.isfile(image_file), f"{image_file} does not exist" + images.append(image_file) + self.all_sequences.append(images) def __len__(self): return len(self.files) def __getitem__(self, index): - (images, dates), label = self.load_data(index) - - images = torch.from_numpy(images).permute(3, 0, 1, 2)#.transpose(0, 1) - label = torch.from_numpy(np.array(label, dtype=np.int32)).long() + images = [np.load(seq) for seq in self.all_sequences[index]] + images = torch.from_numpy(np.stack(images, axis=0)).transpose(0, 1).float() + label = torch.from_numpy(np.load(os.path.join(self.root_path, self.labels[index][1:].replace('tif', 'npy')))).long() output = { 'image': { @@ -219,14 +146,7 @@ def __getitem__(self, index): } return output - #return {'img': images, 'label': label, 'meta': dates} - - # @staticmethod - # def get_splits(dataset_config): - # dataset_train = DynamicEarthNet(cfg=dataset_config, split="train") - # dataset_val = DynamicEarthNet(cfg=dataset_config, split="val") - # dataset_test = DynamicEarthNet(cfg=dataset_config, split="test") - # return dataset_train, dataset_val, dataset_test + @staticmethod def download(self, silent=False): diff --git a/pangaea/datasets/xview2.py b/pangaea/datasets/xview2.py index 845f97c9..de88121e 100644 --- a/pangaea/datasets/xview2.py +++ b/pangaea/datasets/xview2.py @@ -15,10 +15,10 @@ import tarfile from pangaea.datasets.utils import DownloadProgressBar -from pangaea.datasets.base import GeoFMDataset +from pangaea.datasets.base import RawGeoFMDataset -class xView2(GeoFMDataset): +class xView2(RawGeoFMDataset): def __init__( self, split: str, @@ -163,7 +163,7 @@ def __getitem__(self, idx: int) -> Dict[str, Union[torch.Tensor, Any, str]]: img = torch.from_numpy(img.transpose((3, 0, 1, 2))).float() # img_pre = torch.from_numpy(img_pre.transpose((2, 0, 1))).float() # img_post = torch.from_numpy(img_post.transpose((2, 0, 1))).float() - msk = torch.from_numpy(msk).float() + msk = torch.from_numpy(msk).long() return { diff --git a/pangaea/decoders/ltae.py b/pangaea/decoders/ltae.py index 71ab4950..f9674a80 100644 --- a/pangaea/decoders/ltae.py +++ b/pangaea/decoders/ltae.py @@ -6,6 +6,7 @@ import torch import torch.nn as nn + class PositionalEncoder(nn.Module): def __init__(self, d, T=1000, repeat=None, offset=0): super(PositionalEncoder, self).__init__() @@ -36,6 +37,44 @@ def forward(self, batch_positions): +class LTAEChannelAdaptor(nn.Module): + def __init__(self, in_channels: list[int], out_channels: list[int]) -> None: + """LTAEChannelAdaptor for adapting the number of channels of the input features. + + Args: + in_channels (list[int]): list of the number of input channels + out_channels (list[int]): list of the number of output channels + """ + super(LTAEChannelAdaptor, self).__init__() + self.convs = nn.ModuleList( + [ + nn.Conv2d(in_c, out_c, 1) + for in_c, out_c in zip(in_channels, out_channels) + ] + ) + + def forward(self, features: list[torch.Tensor]) -> list[torch.Tensor]: + """Adapter for the number of channels of the input features. + + Args: + features (list[torch.Tensor]): list of features of shape (B C T H W) + from different layers of the encoder. 
+ + Returns: + list[torch.Tensor]: list of adapted features of shape (B C' T H W) + """ + output = [] + for c, f in zip(self.convs, features): + # for all frames + adapted_feature = [] + # features of shape (B C T H W) + for t in range(f.shape[-3]): + adapted_feature.append(c(f[..., t, :, :])) + + output.append(torch.stack(adapted_feature, -3)) + return output + + class LTAE2d(nn.Module): def __init__( self, @@ -111,7 +150,7 @@ def __init__( self.mlp = nn.Sequential(*layers) self.dropout = nn.Dropout(dropout) - def forward(self, x, batch_positions=None, pad_mask=None, return_comp=False): + def forward(self, x, batch_positions=None, pad_mask=None, return_comp=False): sz_b, d, seq_len, h, w = x.shape x = x.permute(0, 2, 1, 3, 4) if pad_mask is not None: @@ -150,8 +189,10 @@ def forward(self, x, batch_positions=None, pad_mask=None, return_comp=False): out = self.out_norm(out) if self.out_norm is not None else out out = out.view(sz_b, h, w, -1).permute(0, 3, 1, 2) - attn = attn.view(self.n_head, sz_b, h, w, seq_len).permute( - 0, 1, 4, 2, 3 + attn = ( + attn.view(self.n_head, sz_b, h, w, seq_len) + .permute(0, 1, 4, 2, 3) + .contiguous() ) # head x b x t x h x w if self.return_att: diff --git a/pangaea/decoders/upernet.py b/pangaea/decoders/upernet.py index 7361a004..df74fef6 100644 --- a/pangaea/decoders/upernet.py +++ b/pangaea/decoders/upernet.py @@ -3,7 +3,7 @@ import torch.nn.functional as F from pangaea.decoders.base import Decoder -from pangaea.decoders.ltae import LTAE2d +from pangaea.decoders.ltae import LTAE2d, LTAEChannelAdaptor from pangaea.encoders.base import Encoder @@ -26,6 +26,7 @@ def __init__( channels: int, pool_scales=(1, 2, 3, 6), feature_multiplier: int = 1, + in_channels: list[int] | None = None, ): super().__init__( encoder=encoder, @@ -42,14 +43,32 @@ def __init__( for param in self.encoder.parameters(): param.requires_grad = False + self.input_layers = self.encoder.output_layers + self.input_layers_num = len(self.input_layers) + + if in_channels is None: + self.in_channels = [ + dim * feature_multiplier for dim in self.encoder.output_dim + ] + else: + self.in_channels = [dim * feature_multiplier for dim in in_channels] + + if self.encoder.pyramid_output: + rescales = [1 for _ in range(self.input_layers_num)] + else: + scales = [4, 2, 1, 0.5] + rescales = [ + scales[int(i / self.input_layers_num * 4)] + for i in range(self.input_layers_num) + ] + self.neck = Feature2Pyramid( - embed_dim=encoder.output_dim * feature_multiplier, - rescales=[4, 2, 1, 0.5], + embed_dim=self.in_channels, + rescales=rescales, ) self.align_corners = False - self.in_channels = [encoder.output_dim * feature_multiplier for _ in range(4)] self.channels = channels self.num_classes = num_classes @@ -238,6 +257,9 @@ def __init__( pool_scales: list[int] = [1, 2, 3, 6], feature_multiplier: int = 1, ) -> None: + decoder_in_channels = self.get_decoder_in_channels( + multi_temporal_strategy, encoder + ) super().__init__( encoder=encoder, num_classes=num_classes, @@ -245,6 +267,7 @@ def __init__( channels=channels, pool_scales=pool_scales, feature_multiplier=feature_multiplier, + in_channels=decoder_in_channels, ) self.multi_temporal = multi_temporal @@ -256,17 +279,36 @@ def __init__( self.tmap = None else: if self.multi_temporal_strategy == "ltae": + ltae_in_channels = max(decoder_in_channels) + # if the encoder output channels vary we must use an adaptor before the LTAE + if decoder_in_channels != encoder.output_dim: + self.ltae_adaptor = LTAEChannelAdaptor( + in_channels=encoder.output_dim, + 
out_channels=decoder_in_channels, + ) + else: + self.ltae_adaptor = lambda x: x self.tmap = LTAE2d( positional_encoding=False, - in_channels=encoder.output_dim, - mlp=[encoder.output_dim, encoder.output_dim], - d_model=encoder.output_dim, + in_channels=ltae_in_channels, + mlp=[ltae_in_channels, ltae_in_channels], + d_model=ltae_in_channels, ) elif self.multi_temporal_strategy == "linear": self.tmap = nn.Linear(self.multi_temporal, 1) else: self.tmap = None + def get_decoder_in_channels( + self, multi_temporal_strategy: str | None, encoder: Encoder + ) -> list[int]: + if multi_temporal_strategy == "ltae": + # if the encoder output channels vary we must use an adaptor before the LTAE + ltae_in_channels = max(encoder.output_dim) + if ltae_in_channels != min(encoder.output_dim): + return [ltae_in_channels for _ in encoder.output_dim] + return encoder.output_dim + def forward( self, img: dict[str, torch.Tensor], output_shape: torch.Size | None = None ) -> torch.Tensor: @@ -308,14 +350,15 @@ def forward( ) feats = [list(i) for i in zip(*feats)] + # obtain features per layer feats = [torch.stack(feat_layers, dim=2) for feat_layers in feats] if self.tmap is not None: - for i in range(len(feats)): - if self.multi_temporal_strategy == "ltae": - feats[i] = self.tmap(feats[i]) - elif self.multi_temporal_strategy == "linear": - feats[i] = self.tmap(feats[i].permute(0, 1, 3, 4, 2)).squeeze(-1) + if self.multi_temporal_strategy == "ltae": + feats = self.ltae_adaptor(feats) + feats = [self.tmap(f) for f in feats] + elif self.multi_temporal_strategy == "linear": + feats = [self.tmap(f.permute(0, 1, 3, 4, 2)).squeeze(-1) for f in feats] feat = self.neck(feats) feat = self._forward_feature(feat) @@ -353,6 +396,8 @@ def __init__( else: raise NotImplementedError + encoder.enforce_single_temporal() + super().__init__( encoder=encoder, num_classes=num_classes, @@ -462,7 +507,13 @@ class RegUPerNet(Decoder): """ def __init__( - self, encoder: Encoder, finetune: bool, channels: int, pool_scales=(1, 2, 3, 6) + self, + encoder: Encoder, + finetune: bool, + channels: int, + pool_scales=(1, 2, 3, 6), + feature_multiplier: int = 1, + in_channels: list[int] | None = None, ): super().__init__( encoder=encoder, @@ -475,13 +526,29 @@ def __init__( for param in self.encoder.parameters(): param.requires_grad = False - self.neck = Feature2Pyramid( - embed_dim=encoder.output_dim, rescales=[4, 2, 1, 0.5] - ) + self.input_layers = self.encoder.output_layers + self.input_layers_num = len(self.input_layers) + + if in_channels is None: + self.in_channels = [ + dim * feature_multiplier for dim in self.encoder.output_dim + ] + else: + self.in_channels = [dim * feature_multiplier for dim in in_channels] + + if self.encoder.pyramid_output: + rescales = [1 for _ in range(self.input_layers_num)] + else: + scales = [4, 2, 1, 0.5] + rescales = [ + scales[int(i / self.input_layers_num * 4)] + for i in range(self.input_layers_num) + ] + + self.neck = Feature2Pyramid(embed_dim=self.in_channels, rescales=rescales) self.align_corners = False - self.in_channels = [encoder.output_dim for _ in range(4)] self.channels = channels self.num_classes = 1 # regression @@ -620,29 +687,54 @@ def __init__( multi_temporal: bool | int, multi_temporal_strategy: str | None, pool_scales=(1, 2, 3, 6), + feature_multiplier: int = 1, ): + decoder_in_channels = self.get_decoder_in_channels( + multi_temporal_strategy, encoder + ) super().__init__( encoder=encoder, finetune=finetune, channels=channels, pool_scales=pool_scales, + feature_multiplier=feature_multiplier, 
+ in_channels=decoder_in_channels, ) - + self.model_name = "Reg_MT_UPerNet" self.multi_temporal = multi_temporal self.multi_temporal_strategy = multi_temporal_strategy if self.multi_temporal_strategy == "ltae": + ltae_in_channels = max(decoder_in_channels) + # if the encoder output channels vary we must use an adaptor before the LTAE + if decoder_in_channels != encoder.output_dim: + self.ltae_adaptor = LTAEChannelAdaptor( + in_channels=encoder.output_dim, + out_channels=decoder_in_channels, + ) + else: + self.ltae_adaptor = lambda x: x self.tmap = LTAE2d( positional_encoding=False, - in_channels=encoder.output_dim, - mlp=[encoder.output_dim, encoder.output_dim], - d_model=encoder.output_dim, + in_channels=ltae_in_channels, + mlp=[ltae_in_channels, ltae_in_channels], + d_model=ltae_in_channels, ) elif self.multi_temporal_strategy == "linear": self.tmap = nn.Linear(self.multi_temporal, 1) else: self.tmap = None + def get_decoder_in_channels( + self, multi_temporal_strategy: str | None, encoder: Encoder + ) -> list[int]: + if multi_temporal_strategy == "ltae": + # if the encoder output channels vary we must use an adaptor before the LTAE + ltae_in_channels = max(encoder.output_dim) + if ltae_in_channels != min(encoder.output_dim): + return [ltae_in_channels for _ in encoder.output_dim] + return encoder.output_dim + def forward( self, img: dict[str, torch.Tensor], output_shape: torch.Size | None = None ) -> torch.Tensor: @@ -669,12 +761,12 @@ def forward( feats = [list(i) for i in zip(*feats)] feats = [torch.stack(feat_layers, dim=2) for feat_layers in feats] - if self.multi_temporal_strategy is not None: - for i in range(len(feats)): - if self.multi_temporal_strategy == "ltae": - feats[i] = self.tmap(feats[i]) - elif self.multi_temporal_strategy == "linear": - feats[i] = self.tmap(feats[i].permute(0, 1, 3, 4, 2)).squeeze(-1) + if self.tmap is not None: + if self.multi_temporal_strategy == "ltae": + feats = self.ltae_adaptor(feats) + feats = [self.tmap(f) for f in feats] + elif self.multi_temporal_strategy == "linear": + feats = [self.tmap(f.permute(0, 1, 3, 4, 2)).squeeze(-1) for f in feats] feat = self.neck(feats) feat = self._forward_feature(feat) @@ -745,57 +837,53 @@ class Feature2Pyramid(nn.Module): embed_dims (int): Embedding dimension. rescales (list[float]): Different sampling multiples were used to obtain pyramid features. Default: [4, 2, 1, 0.5]. - norm_cfg (dict): Config dict for normalization layer. - Default: dict(type='SyncBN', requires_grad=True). 
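# --- Editor's note: illustrative sketch, not part of the patch. Both decoders now derive
# their input widths from the encoder's per-layer output_dim; with the "ltae" strategy every
# layer is widened to the largest dimension so a single LTAE module can fuse them, mirroring
# get_decoder_in_channels above. The dimension lists below are hypothetical examples.
def decoder_in_channels(multi_temporal_strategy, output_dim):
    if multi_temporal_strategy == "ltae" and max(output_dim) != min(output_dim):
        return [max(output_dim)] * len(output_dim)  # adaptor projects every layer to this width
    return list(output_dim)

print(decoder_in_channels("ltae", [256, 512, 1024, 1024]))   # [1024, 1024, 1024, 1024]
print(decoder_in_channels("linear", [768, 768, 768, 768]))   # [768, 768, 768, 768]
# --- end editor's note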
""" def __init__( self, embed_dim, - rescales=[4, 2, 1, 0.5], - norm_cfg=dict(type="SyncBN", requires_grad=True), + rescales=(4, 2, 1, 0.5), ): super().__init__() self.rescales = rescales self.upsample_4x = None - for k in self.rescales: + self.ops = nn.ModuleList() + + for i, k in enumerate(self.rescales): if k == 4: - self.upsample_4x = nn.Sequential( - nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2), - nn.SyncBatchNorm(embed_dim), - nn.GELU(), - nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2), + self.ops.append( + nn.Sequential( + nn.ConvTranspose2d( + embed_dim[i], embed_dim[i], kernel_size=2, stride=2 + ), + nn.SyncBatchNorm(embed_dim[i]), + nn.GELU(), + nn.ConvTranspose2d( + embed_dim[i], embed_dim[i], kernel_size=2, stride=2 + ), + ) ) elif k == 2: - self.upsample_2x = nn.Sequential( - nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2) + self.ops.append( + nn.Sequential( + nn.ConvTranspose2d( + embed_dim[i], embed_dim[i], kernel_size=2, stride=2 + ) + ) ) elif k == 1: - self.identity = nn.Identity() + self.ops.append(nn.Identity()) elif k == 0.5: - self.downsample_2x = nn.MaxPool2d(kernel_size=2, stride=2) + self.ops.append(nn.MaxPool2d(kernel_size=2, stride=2)) elif k == 0.25: - self.downsample_4x = nn.MaxPool2d(kernel_size=4, stride=4) + self.ops.append(nn.MaxPool2d(kernel_size=4, stride=4)) else: raise KeyError(f"invalid {k} for feature2pyramid") def forward(self, inputs): assert len(inputs) == len(self.rescales) outputs = [] - if self.upsample_4x is not None: - ops = [ - self.upsample_4x, - self.upsample_2x, - self.identity, - self.downsample_2x, - ] - else: - ops = [ - self.upsample_2x, - self.identity, - self.downsample_2x, - self.downsample_4x, - ] + for i in range(len(inputs)): - outputs.append(ops[i](inputs[i])) + outputs.append(self.ops[i](inputs[i])) return tuple(outputs) diff --git a/pangaea/encoders/base.py b/pangaea/encoders/base.py index 382c898a..f3c95e44 100644 --- a/pangaea/encoders/base.py +++ b/pangaea/encoders/base.py @@ -57,9 +57,11 @@ def __init__( input_bands: dict[str, list[str]], input_size: int, embed_dim: int, - output_dim: int, + output_layers: list[int], + output_dim: int | list[int], multi_temporal: bool, multi_temporal_output: bool, + pyramid_output: bool, encoder_weights: str | Path, download_url: str, ) -> None: @@ -82,10 +84,17 @@ def __init__( self.input_bands = input_bands self.input_size = input_size self.embed_dim = embed_dim - self.output_dim = output_dim + self.output_layers = output_layers + self.output_dim = ( + [output_dim for _ in output_layers] + if isinstance(output_dim, int) + else list(output_dim) + ) self.encoder_weights = encoder_weights self.multi_temporal = multi_temporal self.multi_temporal_output = multi_temporal_output + + self.pyramid_output = pyramid_output self.download_url = download_url # download_model if necessary @@ -102,6 +111,11 @@ def load_encoder_weights(self, logger: Logger) -> None: """ raise NotImplementedError + def enforce_single_temporal(self): + return + # self.multi_temporal = False + # self.multi_temporal_fusion = False + def parameters_warning( self, missing: dict[str, torch.Size], diff --git a/pangaea/encoders/croma_encoder.py b/pangaea/encoders/croma_encoder.py index dfea553d..4148a053 100644 --- a/pangaea/encoders/croma_encoder.py +++ b/pangaea/encoders/croma_encoder.py @@ -43,6 +43,7 @@ def __init__( input_size: int, input_bands: dict[str, list[str]], output_layers: int | list[int], + output_dim: int | list[int], download_url: str, size="base", ): @@ 
-52,9 +53,11 @@ def __init__( input_bands=input_bands, input_size=input_size, embed_dim=768, - output_dim=768, + output_dim=output_dim, + output_layers=output_layers, multi_temporal=False, multi_temporal_output=False, + pyramid_output=False, download_url=download_url, ) @@ -73,7 +76,6 @@ def __init__( self.num_heads = 16 self.patch_size = 8 - self.output_dim = self.embed_dim self.num_patches = int((self.img_size / 8) ** 2) self.s2_channels = 12 # fixed at 12 multispectral optical channels self.attn_bias = get_2dalibi( @@ -156,6 +158,7 @@ def __init__( input_size: int, input_bands: dict[str, list[str]], output_layers: int | list[int], + output_dim: int | list[int], download_url: str, size="base", ): @@ -165,9 +168,11 @@ def __init__( input_bands=input_bands, input_size=input_size, embed_dim=768, - output_dim=768, + output_layers=output_layers, + output_dim=output_dim, multi_temporal=False, multi_temporal_output=False, + pyramid_output=False, download_url=download_url, ) @@ -186,7 +191,6 @@ def __init__( self.num_heads = 16 self.patch_size = 8 - self.output_dim = self.embed_dim self.num_patches = int((self.img_size / 8) ** 2) self.s1_channels = 2 # fixed at 2 SAR backscatter channels @@ -277,8 +281,11 @@ def __init__( input_size: int, input_bands: dict[str, list[str]], output_layers: int | list[int], + output_dim: int | list[int], download_url: str, + size="base", + ): super().__init__( model_name="croma_joint", @@ -286,9 +293,11 @@ def __init__( input_bands=input_bands, input_size=input_size, embed_dim=768, - output_dim=768, + output_layers=output_layers, + output_dim=output_dim, multi_temporal=False, multi_temporal_output=False, + pyramid_output=False, download_url=download_url, ) @@ -307,7 +316,6 @@ def __init__( self.num_heads = 16 self.patch_size = 8 - self.output_dim = self.embed_dim self.num_patches = int((self.img_size / 8) ** 2) self.s1_channels = 2 # fixed at 2 SAR backscatter channels diff --git a/pangaea/encoders/dofa_encoder.py b/pangaea/encoders/dofa_encoder.py index a5376b37..e636a32f 100644 --- a/pangaea/encoders/dofa_encoder.py +++ b/pangaea/encoders/dofa_encoder.py @@ -180,6 +180,7 @@ def __init__( input_bands: dict[str, list[str]], input_size: int, embed_dim: int, + output_dim: int | list[int], output_layers: int | list[int], wave_list: dict[str, dict[str, float]], download_url: str, @@ -198,9 +199,11 @@ def __init__( input_bands=input_bands, input_size=input_size, embed_dim=embed_dim, - output_dim=embed_dim, + output_layers=output_layers, + output_dim=output_dim, multi_temporal=False, multi_temporal_output=False, + pyramid_output=False, download_url=download_url, ) @@ -216,9 +219,7 @@ def __init__( self.wave_list[m][bi] for m, b in self.input_bands.items() for bi in b ] - self.norm = norm_layer( - [embed_dim, (self.img_size // patch_size), (self.img_size // patch_size)] - ) + self.norm = norm_layer(self.embed_dim) self.patch_embed = Dynamic_MLP_OFA( wv_planes=128, inter_dim=128, kernel_size=16, embed_dim=embed_dim @@ -261,6 +262,8 @@ def forward(self, image): output = [] for i, blk in enumerate(self.blocks): x = blk(x) + if i == len(self.blocks) - 1: + x = self.norm(x) if i in self.output_layers: out = ( x[:, 1:] @@ -273,8 +276,7 @@ def forward(self, image): ) .contiguous() ) - if self.use_norm: - out = self.norm(out) + output.append(out) return output diff --git a/pangaea/encoders/gfmswin_encoder.py b/pangaea/encoders/gfmswin_encoder.py index bccbbb8d..49b640a2 100644 --- a/pangaea/encoders/gfmswin_encoder.py +++ b/pangaea/encoders/gfmswin_encoder.py @@ -645,8 +645,10 @@ 
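# --- Editor's note: illustrative sketch, not part of the patch. The encoders above now
# receive output_dim explicitly, and the base Encoder stores it as one entry per output
# layer: an int (uniform ViTs) is repeated, while a list (pyramid backbones) is kept as-is.
output_layers = [3, 5, 7, 11]
for output_dim in (768, [256, 512, 1024, 1024]):
    dims = [output_dim] * len(output_layers) if isinstance(output_dim, int) else list(output_dim)
    print(dims)
# [768, 768, 768, 768]
# [256, 512, 1024, 1024]
# --- end editor's note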
def __init__( input_size=input_size, embed_dim=embed_dim, output_dim=output_dim, + output_layers=output_layers, multi_temporal=False, multi_temporal_output=False, + pyramid_output=True, download_url=download_url, ) @@ -662,7 +664,6 @@ def __init__( self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) self.mlp_ratio = mlp_ratio # self.use_norm = use_norm - self.out_dim = output_dim self.only_output_last = only_output_last # split image into non-overlapping patches self.patch_embed = PatchEmbed( @@ -767,16 +768,9 @@ def forward(self, image): for i, layer in enumerate(self.layers): x = layer(x) B, L, C = x.shape - if not self.only_output_last: - if i in self.output_layers: - out = x.reshape(B, C, int(L**0.5), int(L**0.5)).repeat( - 1, self.out_dim // C, 1, 1 - ) - out = out.view(B, self.out_dim, int(L**0.5), int(L**0.5)) - output.append(out) - else: - if i == self.num_layers - 1: - output.extend([x.reshape(B, C, int(L**0.5), int(L**0.5))] * 4) + if i in self.output_layers: + out = x.transpose(1, 2).view(B, C, int(L**0.5), int(L**0.5)) + output.append(out) # if self.use_norm: # x = self.norm(x) diff --git a/pangaea/encoders/prithvi_encoder.py b/pangaea/encoders/prithvi_encoder.py index e491c1e6..7a380638 100644 --- a/pangaea/encoders/prithvi_encoder.py +++ b/pangaea/encoders/prithvi_encoder.py @@ -45,6 +45,7 @@ def __init__( encoder_weights: str | Path, input_bands: dict[str, list[str]], input_size: int, + output_dim: int | list[int], output_layers: int | list[int], download_url: str, patch_size=16, @@ -63,15 +64,18 @@ def __init__( input_bands=input_bands, input_size=input_size, embed_dim=embed_dim, - output_dim=embed_dim, + output_layers=output_layers, + output_dim=output_dim, multi_temporal=True, multi_temporal_output=True, + pyramid_output=False, download_url=download_url, ) self.output_layers = output_layers self.img_size = self.input_size + self.tublet_size = tubelet_size if num_frames: self.num_frames = num_frames @@ -193,6 +197,32 @@ def forward(self, image): return output + def enforce_single_temporal(self): + + self.num_frames = 1 + + self.patch_embed = PatchEmbed( + self.input_size, + self.patch_size, + 1, + self.tublet_size, + self.in_chans, + self.embed_dim, + ) + num_patches = self.patch_embed.num_patches + self.pos_embed = nn.Parameter( + torch.zeros(1, num_patches + 1, self.embed_dim), requires_grad=False + ) + + pos_embed = get_3d_sincos_pos_embed( + self.pos_embed.shape[-1], self.patch_embed.grid_size, cls_token=True + ) + self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) + + # initialize patch_embed like nn.Linear (instead of nn.Conv2d) + w = self.patch_embed.proj.weight.data + torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + class PatchEmbed(nn.Module): """Frames of 2D Images to Patch Embedding The 3D version of timm.models.vision_transformer.PatchEmbed diff --git a/pangaea/encoders/remoteclip_encoder.py b/pangaea/encoders/remoteclip_encoder.py index 9ff5a7ee..a904d1a0 100644 --- a/pangaea/encoders/remoteclip_encoder.py +++ b/pangaea/encoders/remoteclip_encoder.py @@ -351,6 +351,7 @@ def __init__( layers: int, mlp_ratio: float, output_layers: int | list[int], + output_dim: int | list[int], download_url: str, ls_init_value: float | None = None, patch_dropout: float = 0.0, @@ -365,9 +366,11 @@ def __init__( input_bands=input_bands, input_size=input_size, embed_dim=embed_dim, - output_dim=embed_dim, + output_layers=output_layers, + output_dim=output_dim, multi_temporal=False, multi_temporal_output=False, + pyramid_output=False, 
download_url=download_url, ) diff --git a/pangaea/encoders/satlasnet_encoder.py b/pangaea/encoders/satlasnet_encoder.py index 9012cf2b..dc09668c 100644 --- a/pangaea/encoders/satlasnet_encoder.py +++ b/pangaea/encoders/satlasnet_encoder.py @@ -418,7 +418,8 @@ def __init__( self, input_bands: dict[str, list[str]], input_size: int, - output_dim: int, + output_dim: int | list[int], + output_layers: int | list[int], model_identifier: str, encoder_weights: str | Path, download_url: str, @@ -433,9 +434,11 @@ def __init__( input_bands=input_bands, input_size=input_size, embed_dim=768, # will be overwritten by the backbone + output_layers=output_layers, output_dim=output_dim, multi_temporal=False if "_SI_" in model_identifier else True, multi_temporal_output=False, + pyramid_output=True, download_url=download_url, ) @@ -450,7 +453,6 @@ def __init__( self.multi_image = model_info["multi_image"] self.backbone_arch = model_info["backbone"] - self.out_dim = self.output_dim self.backbone = self._initialize_backbone( self.in_chans, self.backbone_arch, self.multi_image @@ -542,9 +544,10 @@ def load_encoder_weights(self, logger: Logger) -> None: def forward(self, imgs): # Define forward pass + if not isinstance(self.backbone, AggregationBackbone): + imgs["optical"] = imgs["optical"].squeeze(2) - x = imgs["optical"].squeeze(2) - x = self.backbone(x) + x = self.backbone(imgs["optical"]) if self.fpn: x = self.fpn(x) @@ -552,8 +555,6 @@ def forward(self, imgs): output = [] for i in range(len(x)): - B, C, H, W = x[i].shape - out = x[i].repeat(1, self.out_dim // C, 1, 1) - output.append(out) + output.append(x[i]) return output diff --git a/pangaea/encoders/scalemae_encoder.py b/pangaea/encoders/scalemae_encoder.py index 742f0564..3dbe3b44 100644 --- a/pangaea/encoders/scalemae_encoder.py +++ b/pangaea/encoders/scalemae_encoder.py @@ -54,6 +54,7 @@ def __init__( input_size: int, input_bands: dict[str, list[str]], output_layers: int | list[int], + output_dim: int | list[int], download_url: str, embed_dim: int = 1024, patch_size: int = 16, @@ -71,9 +72,11 @@ def __init__( input_bands=input_bands, input_size=input_size, embed_dim=embed_dim, - output_dim=embed_dim, + output_layers=output_layers, + output_dim=output_dim, multi_temporal=False, multi_temporal_output=False, + pyramid_output=False, download_url=download_url, ) diff --git a/pangaea/encoders/spectralgpt_encoder.py b/pangaea/encoders/spectralgpt_encoder.py index 98a09a7c..0aea1e9e 100644 --- a/pangaea/encoders/spectralgpt_encoder.py +++ b/pangaea/encoders/spectralgpt_encoder.py @@ -45,11 +45,11 @@ def __init__( input_size: int, input_bands: dict[str, list[str]], output_layers: int | list[int], + output_dim: int | list[int], download_url: str, in_chans: int = 3, t_patch_size: int = 3, patch_size: int = 16, - output_dim: int = 768, embed_dim: int = 768, depth: int = 12, num_heads: int = 12, @@ -66,9 +66,11 @@ def __init__( input_bands=input_bands, input_size=input_size, embed_dim=embed_dim, + output_layers=output_layers, output_dim=output_dim, multi_temporal=False, multi_temporal_output=False, + pyramid_output=False, download_url=download_url, ) @@ -198,30 +200,24 @@ def forward(self, image: dict[str, torch.Tensor]) -> list[torch.Tensor]: pos_embed = self.pos_embed[:, :, :] x = x + pos_embed - # reshape to [N, T, L, C] or [N, T*L, C] - requires_t_shape = ( - len(self.blocks) > 0 # support empty decoder - and hasattr(self.blocks[0].attn, "requires_t_shape") - and self.blocks[0].attn.requires_t_shape - ) - if requires_t_shape: - x = x.view([N, T, L, C]) - 
output = [] for i, blk in enumerate(self.blocks): x = blk(x) if i in self.output_layers: if self.cls_embed: - x = x[:, 1:] + out = x[:, 1:] + else: + out = x + out = out.view(N, T, L, C).transpose(2, 3).flatten(1, 2) out = ( - x.permute(0, 2, 1) + out.permute(0, 2, 1) + .contiguous() .view( x.shape[0], -1, self.input_size // self.patch_size, self.input_size // self.patch_size, ) - .contiguous() ) output.append(out) diff --git a/pangaea/encoders/ssl4eo_data2vec_encoder.py b/pangaea/encoders/ssl4eo_data2vec_encoder.py index 5de04436..28ea33cd 100644 --- a/pangaea/encoders/ssl4eo_data2vec_encoder.py +++ b/pangaea/encoders/ssl4eo_data2vec_encoder.py @@ -359,6 +359,7 @@ def __init__( input_size: int, input_bands: dict[str, list[str]], output_layers: int | list[int], + output_dim: int | list[int], download_url: str, in_chans: int = 3, patch_size: int = 16, @@ -386,9 +387,11 @@ def __init__( input_bands=input_bands, input_size=input_size, embed_dim=embed_dim, - output_dim=embed_dim, + output_dim=output_dim, + output_layers=output_layers, multi_temporal=False, multi_temporal_output=False, + pyramid_output=False, download_url=download_url, ) diff --git a/pangaea/encoders/ssl4eo_dino_encoder.py b/pangaea/encoders/ssl4eo_dino_encoder.py index acc16668..308d8bdc 100644 --- a/pangaea/encoders/ssl4eo_dino_encoder.py +++ b/pangaea/encoders/ssl4eo_dino_encoder.py @@ -243,6 +243,7 @@ def __init__( input_size: int, input_bands: dict[str, list[str]], output_layers: int | list[int], + output_dim: int | list[int], download_url: str, in_chans: int = 3, patch_size: int = 16, @@ -265,9 +266,11 @@ def __init__( input_bands=input_bands, input_size=input_size, embed_dim=embed_dim, - output_dim=embed_dim, + output_dim=output_dim, + output_layers=output_layers, multi_temporal=False, multi_temporal_output=False, + pyramid_output=False, download_url=download_url, ) diff --git a/pangaea/encoders/ssl4eo_mae_encoder.py b/pangaea/encoders/ssl4eo_mae_encoder.py index ae4584af..68bbc1b5 100644 --- a/pangaea/encoders/ssl4eo_mae_encoder.py +++ b/pangaea/encoders/ssl4eo_mae_encoder.py @@ -46,6 +46,7 @@ def __init__( input_size: int, input_bands: dict[str, list[str]], output_layers: int | list[int], + output_dim: int | list[int], download_url: str, embed_dim: int = 1024, patch_size: int = 16, @@ -61,9 +62,11 @@ def __init__( input_bands=input_bands, input_size=input_size, embed_dim=embed_dim, - output_dim=embed_dim, + output_dim=output_dim, + output_layers=output_layers, multi_temporal=False, multi_temporal_output=False, + pyramid_output=False, download_url=download_url, ) @@ -198,6 +201,7 @@ def __init__( input_size: int, input_bands: dict[str, list[str]], output_layers: int | list[int], + output_dim: int | list[int], download_url: str, embed_dim: int = 1024, patch_size: int = 16, @@ -212,6 +216,7 @@ def __init__( input_bands=input_bands, input_size=input_size, output_layers=output_layers, + output_dim=output_dim, embed_dim=embed_dim, patch_size=patch_size, in_chans=in_chans, @@ -223,8 +228,7 @@ def __init__( ) self.model_name = "ssl4eo_mae_sar" - self.multi_temporal = False - self.output_dim = embed_dim + def forward(self, image): # embed patches diff --git a/pangaea/encoders/ssl4eo_moco_encoder.py b/pangaea/encoders/ssl4eo_moco_encoder.py index 9f42ef6c..7c74635e 100644 --- a/pangaea/encoders/ssl4eo_moco_encoder.py +++ b/pangaea/encoders/ssl4eo_moco_encoder.py @@ -42,6 +42,7 @@ def __init__( input_size: int, input_bands: dict[str, list[str]], output_layers: int | list[int], + output_dim: int | list[int], download_url: str, 
embed_dim: int = 1024, patch_size: int = 16, @@ -60,9 +61,11 @@ def __init__( input_bands=input_bands, input_size=input_size, embed_dim=embed_dim, - output_dim=embed_dim, + output_layers=output_layers, + output_dim=output_dim, multi_temporal=False, multi_temporal_output=False, + pyramid_output=False, download_url=download_url, ) diff --git a/pangaea/encoders/unet_encoder.py b/pangaea/encoders/unet_encoder.py index a7089e7e..8281db6f 100644 --- a/pangaea/encoders/unet_encoder.py +++ b/pangaea/encoders/unet_encoder.py @@ -25,6 +25,7 @@ def __init__( input_bands: dict[str, list[str]], input_size: int, topology: Sequence[int], + output_dim: int | list[int], download_url: str, encoder_weights: str | None = None, ): @@ -34,12 +35,15 @@ def __init__( input_bands=input_bands, input_size=input_size, embed_dim=0, - output_dim=0, + output_dim=output_dim, + output_layers=None, multi_temporal=False, # single time frame multi_temporal_output=False, + pyramid_output=True, download_url=download_url, ) + # TODO: now only supports optical bands for single time frame self.in_channels = len(input_bands["optical"]) # number of optical bands self.topology = topology @@ -88,7 +92,6 @@ def __init__(self, topology: Sequence[int]): def forward(self, x1: torch.Tensor) -> list: inputs = [x1] - # Downward U: for layer in self.down_seq.values(): out = layer(inputs[-1]) inputs.append(out) diff --git a/pangaea/engine/data_preprocessor.py b/pangaea/engine/data_preprocessor.py index f8b5baa6..ec6b5ebf 100644 --- a/pangaea/engine/data_preprocessor.py +++ b/pangaea/engine/data_preprocessor.py @@ -1,324 +1,105 @@ import logging import math import random +import numbers import numpy as np import torch import torchvision.transforms as T +import torchvision.transforms.functional as TF from torch.utils.data import Dataset -from pangaea.datasets.base import GeoFMDataset -from pangaea.encoders.base import Encoder +from hydra.utils import instantiate +from typing import Callable, Dict, List, Optional, Sequence, Union, Tuple +import copy -class RichDataset(Dataset): - """Dataset wrapper to add preprocessing steps.""" - def __init__(self, dataset: GeoFMDataset, encoder: Encoder): - """Initialize the RichDataset. - - Args: - dataset (GeoFMDataset): dataset used. - encoder (Encoder): encoder used. - """ - self.dataset = dataset - self.encoder = encoder - - # WARNING: Patch to overcome recursive wrapping issues - self.split = dataset.split - self.dataset_name = dataset.dataset_name - self.multi_modal = dataset.multi_modal - self.multi_temporal = dataset.multi_temporal - self.root_path = dataset.root_path - self.classes = dataset.classes - self.num_classes = dataset.num_classes - self.ignore_index = dataset.ignore_index - self.img_size = dataset.img_size - self.bands = dataset.bands - self.distribution = dataset.distribution - self.data_mean = dataset.data_mean - self.data_std = dataset.data_std - self.data_min = dataset.data_min - self.data_max = dataset.data_max - self.download_url = dataset.download_url - self.auto_download = dataset.auto_download - - def __getitem__( - self, index: int - ) -> dict[str, torch.Tensor | dict[str, torch.Tensor]]: - """Return a modified item from the dataset. - - Args: - index (int): index of data. 
- Returns: - dict[str, torch.Tensor | dict[str, torch.Tensor]]: output dictionary following the format - {"image": - { - "optical": torch.Tensor of shape (C H W) (or (C T H W) if multi-temporal dataset), - "sar": torch.Tensor of shape (C H W) (or (C T H W) if multi-temporal dataset) - }, - "target": torch.Tensor of shape (H W), - "metadata": dict}. - """ - return self.dataset[index] - - def __len__(self) -> int: - """Return the length of the dataset. - - Returns: - int: length of the dataset. - """ - return len(self.dataset) +class BasePreprocessor(): + """Base class for preprocessor.""" + def __init__(self,) -> None: + return -class SegPreprocessor(RichDataset): - def __init__(self, dataset: GeoFMDataset, encoder: Encoder) -> None: - """Initialize the SegPreprocessor for segmentation tasks. - - Args: - dataset (GeoFMDataset): dataset used. - encoder (Encoder): encoder used. - """ - super().__init__(dataset, encoder) - - self.preprocessor = {} - self.preprocessor["optical"] = ( - BandAdaptor(dataset=dataset, encoder=encoder, modality="optical") - if "optical" in dataset.bands.keys() - else None - ) - self.preprocessor["sar"] = ( - BandAdaptor(dataset=dataset, encoder=encoder, modality="sar") - if "sar" in dataset.bands.keys() - else None - ) - for modality in self.encoder.input_bands: - new_stats = self.preprocessor[modality].preprocess_band_statistics( - self.dataset.data_mean[modality], - self.dataset.data_std[modality], - self.dataset.data_min[modality], - self.dataset.data_max[modality], - ) - - self.dataset.data_mean[modality] = new_stats[0] - self.dataset.data_std[modality] = new_stats[1] - self.dataset.data_min[modality] = new_stats[2] - self.dataset.data_max[modality] = new_stats[3] - - def __getitem__( - self, index: int + def __call__( + self, data: dict[str, torch.Tensor | dict[str, torch.Tensor]] ) -> dict[str, torch.Tensor | dict[str, torch.Tensor]]: - """Return a modified item from the dataset. - - Args: - index (int): index of data. - Returns: - dict[str, torch.Tensor | dict[str, torch.Tensor]]: output dictionary following the format - {"image": - { - encoder_modality_1: torch.Tensor of shape (C T H W) (with T=1 if single timeframe) - ... - encoder_modality_N: torch.Tensor of shape (C T H W) (with T=1 if single timeframe) - }, - "target": torch.Tensor of shape (H W), - "metadata": dict}. - """ - data = self.dataset[index] - for k, v in data["image"].items(): - if k in self.encoder.input_bands: - data["image"][k] = self.preprocessor[k](v) - - data["target"] = data["target"].long() - return data + raise NotImplementedError + def update_meta(self, meta): + raise NotImplementedError -class RegPreprocessor(SegPreprocessor): - def __init__(self, dataset: GeoFMDataset, encoder: Encoder) -> None: - """Initialize the RegPreprocessor for regression tasks.""" - super().__init__(dataset, encoder) - - def __getitem__( - self, index: int - ) -> dict[str, torch.Tensor | dict[str, torch.Tensor]]: - """Return a modified item from the dataset. - - Args: - index (int): index of data. - Returns: - dict[str, torch.Tensor | dict[str, torch.Tensor]]: output dictionary following the format - {"image": - { - encoder_modality_1: torch.Tensor of shape (C T H W) (with T=1 if single timeframe) - ... - encoder_modality_N: torch.Tensor of shape (C T H W) (with T=1 if single timeframe) - }, - "target": torch.Tensor of shape (H W), - "metadata": dict}. 
- """ - data = self.dataset[index] + def check_dimension(self, data: dict[str, torch.Tensor | dict[str, torch.Tensor]]): + """check dimension (C, T, H, W) of data""" for k, v in data["image"].items(): - if k in self.encoder.input_bands: - data["image"][k] = self.preprocessor[k](v) - data["target"] = data["target"].float() - return data + if len(v.shape) != 4: + raise AssertionError(f"Image dimension must be 4 (C, T, H, W), Got {str(len(v.shape))}") + if len(data["target"].shape) != 2: + raise AssertionError(f"Target dimension must be 2 (H, W), Got {str(len(data['target'].shape))}") -class BandAdaptor: - def __init__(self, dataset: GeoFMDataset, encoder: Encoder, modality: str) -> None: - """Intialize the BandAdaptor. - Args: - dataset (GeoFMDataset): dataset used. - encoder (Encoder): encoder unded. - modality (str): image modality. - """ - self.dataset_bands = dataset.bands[modality] - self.input_bands = getattr(encoder.input_bands, modality, []) - - # list of length dataset_n_bands with True if the band is used in the encoder - # and is available in the dataset - self.used_bands_mask = torch.tensor( - [b in self.input_bands for b in self.dataset_bands], dtype=torch.bool - ) - # list of length encoder_n_bands with True if the band is available in the dataset - # and used in the encoder - self.avail_bands_mask = torch.tensor( - [b in self.dataset_bands for b in self.input_bands], dtype=torch.bool - ) - # list of length encoder_n_bands with the index of the band in the dataset - # if the band is available in the dataset and -1 otherwise - self.avail_bands_indices = torch.tensor( - [ - self.dataset_bands.index(b) if b in self.dataset_bands else -1 - for b in self.input_bands - ], - dtype=torch.long, - ) - - # if the encoder requires bands that are not available in the dataset - # then we need to pad the input with zeros - self.need_padded = self.avail_bands_mask.sum() < len(self.input_bands) - self.logger = logging.getLogger() - self.logger.info(f"Adaptor for modality: {modality}") - self.logger.info( - "Available bands in dataset: {}".format( - " ".join(str(b) for b in self.dataset_bands) - ) - ) - self.logger.info( - "Required bands in encoder: {}".format( - " ".join(str(b) for b in self.input_bands) - ) - ) - if self.need_padded: - self.logger.info( - "Unavailable bands {} are padded with zeros".format( - " ".join( - str(b) - for b in np.array(self.input_bands)[ - self.avail_bands_mask.logical_not() - ] - ) - ) - ) + def check_size(self, data: dict[str, torch.Tensor | dict[str, torch.Tensor]]): + """check if data size is equal""" + base_shape = data["image"][list(data["image"].keys())[0]].shape - def preprocess_band_statistics( - self, - data_mean: list[float], - data_std: list[float], - data_min: list[float], - data_max: list[float], - ) -> tuple[ - list[float], - list[float], - list[float], - list[float], - ]: - """Filter the statistics to match the available bands. - Args: - data_mean (list[float]): dataset mean (per band in dataset). - data_std (list[float]): dataset std (per band in dataset). - data_min (list[float]): dataset min (per band in dataset). - data_max (list[float]): dataset max (per band in dataset). - Returns: - tuple[ list[float], list[float], list[float], list[float], ]: - dataset mean, std, min, max (per band in encoder). Pad with zeros - if the band is required by the encoder but not included in the dataset. 
- """ - data_mean = [ - data_mean[i] if i != -1 else 0.0 for i in self.avail_bands_indices.tolist() - ] - data_std = [ - data_std[i] if i != -1 else 1.0 for i in self.avail_bands_indices.tolist() - ] - data_min = [ - data_min[i] if i != -1 else -1.0 for i in self.avail_bands_indices.tolist() - ] - data_max = [ - data_max[i] if i != -1 else 1.0 for i in self.avail_bands_indices.tolist() - ] - return data_mean, data_std, data_min, data_max - - def preprocess_single_timeframe(self, image: torch.Tensor) -> torch.Tensor: - """Apply the preprocessing to a single timeframe, i.e. pad unavailable - bands with zeros if needed to match encoder's bands. - Args: - image (torch.Tensor): input image of shape (dataset_n_bands H W). - Returns: - torch.Tensor: output image of shape (encoder_n_bands H W). - """ - # add padding band at index 0 on the first dim - padded_image = torch.cat([torch.zeros_like(image[0:1]), image], dim=0) - # request all encoder's band. In self.avail_band_indices we have - # -1 for bands not available in the dataset. So we add 1 to get the - # correct index in the padded image (index 0 is the 0-padding band) - return padded_image[self.avail_bands_indices + 1] - - def __call__(self, image: torch.Tensor) -> torch.Tensor: - """Apply the preprocessing to the image. Pad unavailable bands with zeros. - Args: - image (torch.Tensor): image of shape (dataset_n_bands H W). - Returns: - torch.Tensor: output image of shape (encoder_n_bands T H W). - In the case of sigle timeframe, T = 1. - """ - # input of shape (dataset_n_bands T H W) output of shape (encoder_n_bands T H W) - if len(image.shape) == 3: - # Add a time dimension so preprocessing can work on consistent images - image = image.unsqueeze( - 1 - ) # (dataset_n_bands H W)-> (dataset_n_bands 1 H W) - - if image.shape[1] != 1: - final_image = [] - for i in range(image.shape[1]): - final_image.append(self.preprocess_single_timeframe(image[:, i, :, :])) - image = torch.stack(final_image, dim=1) - else: - image = self.preprocess_single_timeframe(image) - - # OUTPUT SHAPE (encoder_n_bands T H W) (T = 1 in the case of single timeframe) - return image + for k, v in data["image"].items(): + if v.shape[1:] != base_shape[1:]: + shape = {k: tuple(v.shape[1:]) for k, v in data["image"].items()} + raise AssertionError(f"Image size (T, H, W) from all modalities must be equal, Got {str(shape)}") + if base_shape[-2:] != data["target"].shape[-2:]: + raise AssertionError(f"Image size and target size (H, W) must be equal, Got {str(tuple(base_shape[-2:]))} and {str(tuple(data['target'].shape[-2:]))}") -class BaseAugment(RichDataset): - """Base class for augmentations.""" - def __init__(self, dataset: GeoFMDataset, encoder: Encoder) -> None: - """Augment item. Should call the dataset __getitem__ method which output a dictionary - with the keys "image", "target" and "metadata": - dict[str, torch.Tensor | dict[str, torch.Tensor]]: output dictionary following the format - {"image": - { - encoder_modality_1: torch.Tensor of shape (C T H W) (T=1 if single timeframe), - ... - encoder_modality_N: torch.Tensor of shape (C T H W) (T=1 if single timeframe), - }, - "target": torch.Tensor of shape (H W), - "metadata": dict}. +class Preprocessor(BasePreprocessor): + """A series of base preprocessors that preprocess images and targets.""" + def __init__( + self, + preprocessor_cfg, + dataset_cfg, + encoder_cfg + ) -> None: + """Build preprocessors defined in preprocessor_cfg. 
+ Args: + preprocessor_cfg: preprocessor config + dataset_cfg: dataset config + encoder_cfg: encoder config + """ + super().__init__() + # initialize the meta statistics/info of the input data and target encoder + meta = {} + meta['dataset_img_size'] = dataset_cfg['img_size'] + meta['encoder_input_size'] = encoder_cfg['input_size'] + meta['dataset_bands'] = dataset_cfg['bands'] + meta['encoder_bands'] = encoder_cfg['input_bands'] + meta['multi_modal'] = dataset_cfg['multi_modal'] + meta['multi_temporal'] = dataset_cfg['multi_temporal'] + + meta['data_bands'] = dataset_cfg['bands'] + meta['data_img_size'] = dataset_cfg['img_size'] + meta['data_mean'] = {k: torch.tensor(v) for k, v in dataset_cfg['data_mean'].items()} + meta['data_std'] = {k: torch.tensor(v) for k, v in dataset_cfg['data_std'].items()} + meta['data_min'] = {k: torch.tensor(v) for k, v in dataset_cfg['data_min'].items()} + meta['data_max'] = {k: torch.tensor(v) for k, v in dataset_cfg['data_max'].items()} + + meta['ignore_index'] = dataset_cfg['ignore_index'] + meta['class_distribution'] = torch.tensor(dataset_cfg['distribution']) + + self.preprocessor = [] + + # build the preprocessor and update the meta for the next + for preprocess in preprocessor_cfg: + preprocessor = instantiate(preprocess, **meta) + meta = preprocessor.update_meta(meta) + self.preprocessor.append(preprocessor) + + def __call__( + self, data: dict[str, torch.Tensor | dict[str, torch.Tensor]] + ) -> dict[str, torch.Tensor | dict[str, torch.Tensor]]: + """preprocess images and targets step by step. Args: - index (int): index of data. - + data (dict): input data. Returns: dict[str, torch.Tensor | dict[str, torch.Tensor]]: output dictionary following the format {"image": @@ -330,57 +111,45 @@ def __init__(self, dataset: GeoFMDataset, encoder: Encoder) -> None: "target": torch.Tensor of shape (H W), "metadata": dict}. """ - super().__init__(dataset, encoder) + self.check_dimension(data) + for process in self.preprocessor: + data = process(data) + return data + + +class BandFilter(BasePreprocessor): -class Tile(BaseAugment): def __init__( - self, dataset: GeoFMDataset, encoder: Encoder, min_overlap: int = 0 + self, + **meta ) -> None: - """Initialize the Tiling augmentation. - + """Intialize the BandFilter. Args: - dataset (GeoFMDataset): dataset used. - encoder (Encoder): encoder used. - min_overlap (int, optional): minimum overlap between tiles. Defaults to 0. + meta: statistics/info of the input data and target encoder + data_bands: bands of incoming data + encoder_bands: expected bands by encoder """ - super().__init__(dataset, encoder) - self.min_overlap = min_overlap - # Should be the _largest_ image in the dataset to avoid problems mentioned in __getitem__ - self.input_size = self.dataset.img_size - self.output_size = self.encoder.input_size - if self.output_size == self.input_size: - self.tiles_per_dim = 1 - elif self.output_size > self.input_size: - raise ValueError( - f"Can't tile inputs if dataset.img_size={self.input_size} < encoder.input_size={self.output_size}, use ResizeToEncoder instead." 
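# --- Editor's note: illustrative sketch, not part of the patch. The Preprocessor above
# chains steps that each transform a sample and hand an updated `meta` dict to the next
# step. `NoOp` is a hypothetical stand-in for the hydra-instantiated preprocessors
# (BandFilter, NormalizeMeanStd, ...).
class NoOp:
    def __init__(self, **meta): pass
    def __call__(self, data): return data
    def update_meta(self, meta): return meta

meta = {"data_img_size": 1024, "encoder_input_size": 224}
steps = []
for step_cls in (NoOp, NoOp):
    step = step_cls(**meta)        # the real code uses hydra.utils.instantiate(cfg, **meta)
    meta = step.update_meta(meta)  # later steps see the statistics left by earlier ones
    steps.append(step)

sample = {"image": {}, "target": None, "metadata": {}}
for step in steps:
    sample = step(sample)
# --- end editor's note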
- ) - elif self.min_overlap >= self.input_size: - raise ValueError("min_overlap >= dataset.img_size") - elif self.min_overlap >= self.input_size: - raise ValueError("min_overlap >= encoder.input_size") - else: - self.tiles_per_dim = math.ceil( - (self.input_size - self.min_overlap) - / (self.output_size - self.min_overlap) - ) + super().__init__() - logging.getLogger().info( - f"Tiling {self.input_size}x{self.input_size} input images to {self.tiles_per_dim * self.tiles_per_dim} {self.output_size}x{self.output_size} output images." - ) + self.used_bands_indices = {} - self.h_spacing_cache = [None] * super().__len__() - self.w_spacing_cache = [None] * super().__len__() + for k in meta['data_bands'].keys(): + if k not in meta['encoder_bands'].keys(): + continue + self.used_bands_indices[k] = torch.tensor( + [meta['data_bands'][k].index(b) for b in meta['encoder_bands'][k] if b in meta['data_bands'][k]], dtype=torch.long + ) - self.data_cache = (None, None) + if not self.used_bands_indices: + raise ValueError("No nontrivial input bands after BandFilter!") - def __getitem__( - self, index: int + def __call__( + self, data: dict[str, torch.Tensor | dict[str, torch.Tensor]] ) -> dict[str, torch.Tensor | dict[str, torch.Tensor]]: - """Apply Tiling to the data. + """Filter redundant bands from the data. Args: - index (int): index of data. - + data (dict): input data. Returns: dict[str, torch.Tensor | dict[str, torch.Tensor]]: output dictionary following the format {"image": @@ -392,155 +161,119 @@ def __getitem__( "target": torch.Tensor of shape (H W), "metadata": dict}. """ - if self.tiles_per_dim == 1: - return self.dataset[index] - - dataset_index = math.floor(index / (self.tiles_per_dim * self.tiles_per_dim)) - data = self.dataset[dataset_index] - # Calculate tile coordinates - tile_index = index % (self.tiles_per_dim * self.tiles_per_dim) - h_index = math.floor(tile_index / self.tiles_per_dim) - w_index = tile_index % self.tiles_per_dim - # Use the actual image size so we can handle data that's not always uniform. - # This means that min_overlap might not always be respected. - # Also, in case there was insufficient overlap (or tiles_per_dim=1) sepcified, we'll crop the image and lose info. 
- input_h, input_w = data["image"][next(iter(data["image"].keys()))].shape[-2:] - - # Calculate the sizes of the labeled parts seperately to deal with aliasing when - # tile spacing values are not exact integers - if not self.h_spacing_cache[dataset_index]: - float_spacing = np.linspace( - 0, input_h - self.output_size, self.tiles_per_dim - ) - rounded_spacing = float_spacing.round().astype(int) - labeled_sizes = np.ediff1d(rounded_spacing, to_end=self.output_size) - self.h_spacing_cache[dataset_index] = (rounded_spacing, labeled_sizes) - if not self.w_spacing_cache[dataset_index]: - float_spacing = np.linspace( - 0, input_w - self.output_size, self.tiles_per_dim - ) - rounded_spacing = float_spacing.round().astype(int) - labeled_sizes = np.ediff1d(rounded_spacing, to_end=self.output_size) - self.w_spacing_cache[dataset_index] = (rounded_spacing, labeled_sizes) - h_positions, h_labeled_sizes = self.h_spacing_cache[dataset_index] - w_positions, w_labeled_sizes = self.w_spacing_cache[dataset_index] + data["image"] = {k: data["image"][k][v] for k, v in self.used_bands_indices.items()} - h, w = h_positions[h_index], w_positions[w_index] - h_labeled, w_labeled = h_labeled_sizes[h_index], w_labeled_sizes[w_index] + return data - tiled_data = {"image": {}, "target": None} - tiled_data["image"] = {} - for k, v in data["image"].items(): - if k in self.encoder.input_bands: - tiled_data["image"][k] = v[ - ..., h : h + self.output_size, w : w + self.output_size - ].clone() - - # Place the mesaured part in the middle to help with tiling artefacts - h_label_offset = round((self.output_size - h_labeled) / 2) - w_label_offset = round((self.output_size - w_labeled) / 2) - - # Crop target to size - tiled_data["target"] = data["target"][ - ..., h : h + self.output_size, w : w + self.output_size - ].clone() - - # Ignore overlapping borders - if h_index != 0: - tiled_data["target"][..., 0:h_label_offset, :] = self.dataset.ignore_index - if w_index != 0: - tiled_data["target"][..., 0:w_label_offset] = self.dataset.ignore_index - if h_index != self.tiles_per_dim - 1: - tiled_data["target"][..., self.output_size - h_label_offset :, :] = ( - self.dataset.ignore_index - ) - if w_index != self.tiles_per_dim - 1: - tiled_data["target"][..., self.output_size - w_label_offset :] = ( - self.dataset.ignore_index - ) + def update_meta(self, meta): + """Tracking the meta statistics/info for next processor.""" + for k in list(meta['data_bands'].keys()): + if k not in self.used_bands_indices.keys(): + meta['data_bands'].pop(k, None) + meta['data_mean'].pop(k, None) + meta['data_std'].pop(k, None) + meta['data_min'].pop(k, None) + meta['data_max'].pop(k, None) + else: + meta['data_bands'][k] = [meta['data_bands'][k][i.item()] for i in self.used_bands_indices[k]] + meta['data_mean'][k] = meta['data_mean'][k][self.used_bands_indices[k]] + meta['data_std'][k] = meta['data_std'][k][self.used_bands_indices[k]] + meta['data_min'][k] = meta['data_min'][k][self.used_bands_indices[k]] + meta['data_max'][k] = meta['data_max'][k][self.used_bands_indices[k]] - return tiled_data + return meta - def __len__(self): - return (super().__len__()) * (self.tiles_per_dim * self.tiles_per_dim) +class BandPadding(BasePreprocessor): -class RandomFlip(BaseAugment): def __init__( - self, - dataset: GeoFMDataset, - encoder: Encoder, - ud_probability: float, - lr_probability: float, + self, + fill_value: float = 0.0, + **meta ) -> None: - """Initialize the RandomFlip. + """Intialize the BandPadding. Args: - dataset (GeoFMDataset): dataset used. 
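# --- Editor's note: illustrative sketch, not part of the patch. This is the band selection
# BandFilter above performs: keep only the dataset bands the encoder expects, in the
# encoder's order. Band names are example Sentinel-2 identifiers.
import torch

data_bands = {"optical": ["B2", "B3", "B4", "B8"]}
encoder_bands = {"optical": ["B4", "B3", "B2"]}
indices = {
    k: torch.tensor([data_bands[k].index(b) for b in encoder_bands[k] if b in data_bands[k]], dtype=torch.long)
    for k in data_bands if k in encoder_bands
}
image = {"optical": torch.randn(4, 1, 8, 8)}              # (C, T, H, W)
filtered = {k: image[k][v] for k, v in indices.items()}
print(filtered["optical"].shape)  # torch.Size([3, 1, 8, 8])
# --- end editor's note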
- encoder (Encoder): encoder used. - ud_probability (float): Up/Down augmentation probability. - lr_probability (float): Left/Right augmentation probability. + fill_value (float): fill value for padding. + meta: statistics/info of the input data and target encoder + data_bands: bands of incoming data + encoder_bands: expected bands by encoder """ - super().__init__(dataset, encoder) - self.ud_probability = ud_probability - self.lr_probability = lr_probability + super().__init__() - def __getitem__( - self, index: int + self.fill_value = fill_value + self.data_img_size = meta['data_img_size'] + + self.encoder_bands = meta['encoder_bands'] + self.avail_bands_mask, self.used_bands_indices = {}, {} + for k in meta['encoder_bands'].keys(): + if k in meta['data_bands'].keys(): + self.avail_bands_mask[k] = torch.tensor( + [b in meta['data_bands'][k] for b in meta['encoder_bands'][k]], dtype=torch.bool + ) + self.used_bands_indices[k] = torch.tensor( + [meta['data_bands'][k].index(b) for b in meta['encoder_bands'][k] if b in meta['data_bands'][k]], dtype=torch.long + ) + else: + self.avail_bands_mask[k] = torch.zeros(len(meta['encoder_bands'][k]), dtype=torch.bool) + if not self.used_bands_indices: + raise ValueError("No nontrivial input bands after BandPadding!") + + def __call__( + self, data: dict[str, torch.Tensor | dict[str, torch.Tensor]] ) -> dict[str, torch.Tensor | dict[str, torch.Tensor]]: - """Apply Random FLIP to the data. - Args: - index (int): index of data. - Returns: - dict[str, torch.Tensor | dict[str, torch.Tensor]]: output dictionary following the format - {"image": - { - encoder_modality_1: torch.Tensor of shape (C T H W) (T=1 if single timeframe), - ... - encoder_modality_N: torch.Tensor of shape (C T H W) (T=1 if single timeframe), - }, - "target": torch.Tensor of shape (H W), - "metadata": dict}. 
- """ - data = self.dataset[index] - if random.random() < self.lr_probability: - for k, v in data["image"].items(): - if k in self.encoder.input_bands: - data["image"][k] = torch.fliplr(v) - data["target"] = torch.fliplr(data["target"]) - if random.random() < self.ud_probability: - for k, v in data["image"].items(): - if k in self.encoder.input_bands: - data["image"][k] = torch.flipud(v) - data["target"] = torch.flipud(data["target"]) + for k in self.avail_bands_mask.keys(): + if k in self.used_bands_indices.keys(): + size = self.avail_bands_mask[k].shape + data["image"][k].shape[1:] + padded_image = torch.full(size, fill_value=self.fill_value, dtype=data["image"][k].dtype) + padded_image[self.avail_bands_mask[k]] = data["image"][k][self.used_bands_indices[k]] + else: + reference = data["image"](list(data["image"].keys())[0]) + size = self.avail_bands_mask[k].shape + reference.shape[1:] + padded_image = torch.full(size, fill_value=self.fill_value, dtype=reference.dtype) + + data["image"][k] = padded_image return data - -class GammaAugment(BaseAugment): + def update_meta(self, meta): + """Tracking the meta statistics/info for next processor.""" + meta['data_bands'] = meta['encoder_bands'] + for k in self.avail_bands_mask.keys(): + size = self.avail_bands_mask[k].shape + meta['data_mean'][k] = torch.full(size, fill_value=self.fill_value, dtype=torch.float) + meta['data_std'][k] = torch.ones(size, dtype=torch.float) + meta['data_min'][k] = torch.full(size, fill_value=self.fill_value, dtype=torch.float) + meta['data_max'][k] = torch.full(size, fill_value=self.fill_value, dtype=torch.float) + if self.used_bands_indices[k] is not None: + meta['data_mean'][k][self.avail_bands_mask[k]] = meta['data_mean'][k][self.used_bands_indices[k]] + meta['data_std'][k][self.avail_bands_mask[k]] = meta['data_std'][k][self.used_bands_indices[k]] + meta['data_min'][k][self.avail_bands_mask[k]] = meta['data_min'][k][self.used_bands_indices[k]] + meta['data_max'][k][self.avail_bands_mask[k]] = meta['data_max'][k][self.used_bands_indices[k]] + return meta + + +class NormalizeMeanStd(BasePreprocessor): def __init__( - self, - dataset: GeoFMDataset, - encoder: Encoder, - probability: float, - gamma_range: float, + self, + **meta, ) -> None: - """Initialize the GammaAugment. + """Initialize the NormalizeMeanStd. Args: - dataset (GeoFMDataset): dataset used. - encoder (Encoder): encoder used. - probability (float): probability of applying the augmentation. - gamma_range (float): gamma range. + meta: statistics/info of the input data and target encoder + data_mean: global mean of incoming data + data_std: global std of incoming data """ - super().__init__(dataset, encoder) - self.probability = probability - self.gamma_range = gamma_range + super().__init__() + + self.data_mean = meta['data_mean'] + self.data_std = meta['data_std'] - def __getitem__( - self, index: int + def __call__( + self, data: dict[str, torch.Tensor | dict[str, torch.Tensor]] ) -> dict[str, torch.Tensor | dict[str, torch.Tensor]]: - """Apply Gamma Augment to the data. + """Apply Mean/Std Normalization to the data. Args: - index (int): index of data. + data (dict): input data. Returns: dict[str, torch.Tensor | dict[str, torch.Tensor]]: output dictionary following the format {"image": @@ -553,43 +286,42 @@ def __getitem__( "metadata": dict}. 
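# --- Editor's note: illustrative sketch, not part of the patch. This is the layout
# BandPadding above aims for: available dataset bands are placed into the encoder's band
# order and missing bands are filled with a constant (zero by default). Band names are
# example values.
import torch

encoder_bands = ["B2", "B3", "B4", "B8A"]
data_bands = ["B2", "B3", "B4"]
avail = torch.tensor([b in data_bands for b in encoder_bands])
src = torch.tensor([data_bands.index(b) for b in encoder_bands if b in data_bands])
image = torch.randn(len(data_bands), 1, 4, 4)             # (C, T, H, W)
padded = torch.zeros((len(encoder_bands),) + image.shape[1:], dtype=image.dtype)
padded[avail] = image[src]
print(padded.shape)  # torch.Size([4, 1, 4, 4])
# --- end editor's note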
""" - data = self.dataset[index] - # WARNING: Test this bit of code - if random.random() < self.probability: - for k, v in data["image"].items() and k in self.encoder.input_bands: - data["image"][k] = torch.pow(v, random.uniform(*self.gamma_range)) + for k in self.data_mean.keys(): + data["image"][k].sub_(self.data_mean[k].view(-1, 1, 1, 1)).div_(self.data_std[k].view(-1, 1, 1, 1)) return data + def update_meta(self, meta): + """Tracking the meta statistics/info for next processor.""" + meta['data_mean'] = {k: torch.zeros_like(v) for k, v in meta['data_mean'].items()} + meta['data_std'] = {k: torch.ones_like(v) for k, v in meta['data_std'].items()} + meta['data_min'] = {k: (v - meta['data_mean'][k]) / meta['data_std'][k] for k, v in meta['data_min'].items()} + meta['data_max'] = {k: (v - meta['data_mean'][k]) / meta['data_std'][k] for k, v in meta['data_max'].items()} + + return meta -class NormalizeMeanStd(BaseAugment): + +class NormalizeMinMax(BasePreprocessor): def __init__( - self, - dataset: GeoFMDataset, - encoder: Encoder + self, + **meta, ) -> None: """Initialize the NormalizeMeanStd. Args: - dataset (GeoFMDataset): dataset used. - encoder (Encoder): encoder used. + meta: statistics/info of the input data and target encoder + data_min: global maximum value of incoming data + data_sax: global minimum value of incoming data """ - super().__init__(dataset, encoder) - self.data_mean_tensors = {} - self.data_std_tensors = {} - # Bands is a dict of {modality:[b1, b2, ...], ...} so it's keys are the modalaities in use - for modality in self.encoder.input_bands: - self.data_mean_tensors[modality] = torch.tensor( - self.dataset.data_mean[modality] - ).reshape((-1, 1, 1, 1)) - self.data_std_tensors[modality] = torch.tensor( - self.dataset.data_std[modality] - ).reshape((-1, 1, 1, 1)) - - def __getitem__( - self, index: int + super().__init__() + + self.data_min = meta['data_min'] + self.data_max = meta['data_max'] + + def __call__( + self, data: dict[str, torch.Tensor | dict[str, torch.Tensor]] ) -> dict[str, torch.Tensor | dict[str, torch.Tensor]]: """Apply Mean/Std Normalization to the data. Args: - index (int): index of data. + data (dict): input data. Returns: dict[str, torch.Tensor | dict[str, torch.Tensor]]: output dictionary following the format {"image": @@ -601,49 +333,87 @@ def __getitem__( "target": torch.Tensor of shape (H W), "metadata": dict}. 
""" - data = self.dataset[index] - for modality in self.encoder.input_bands: - data["image"][modality] = ( - data["image"][modality] - self.data_mean_tensors[modality] - ) / self.data_std_tensors[modality] + + for k in self.data_min.keys(): + size = (-1,) + data["image"][k].shape[1:] + data["image"][k].sub_(self.data_min[k].view(size)).div_((self.data_max[k]-self.data_min[k]).view(size)) return data + def update_meta(self, meta): + """Tracking the meta statistics/info for next processor.""" + meta['data_mean'] = {k: (v - meta['data_min'][k]) / (meta['data_max'][k] - meta['data_min'][k]) for k, v in meta['data_mean'].items()} + meta['data_std'] = {k: v / (meta['data_max'][k] - meta['data_min'][k]) for k, v in meta['data_std'].items()} + meta['data_min'] = {k: torch.zeros_like(v) for k, v in meta['data_mean'].items()} + meta['data_max'] = {k: torch.ones_like(v) for k, v in meta['data_std'].items()} -class NormalizeMinMax(BaseAugment): + return meta + + +class RandomCrop(BasePreprocessor): def __init__( self, - dataset: GeoFMDataset, - encoder: Encoder, - data_min: torch.Tensor, - data_max: torch.Tensor, + size: int | Sequence[int], + pad_if_needed: bool = False, + **meta ) -> None: - """Apply Min/Max Normalization to scale the data to the range [data_min, data_max]. + """Initialize the RandomCrop preprocessor. + Args: + size (int): crop size. + pad_if_needed (bool, optional): whether to pad. Defaults to False. + meta: statistics/info of the input data and target encoder + data_mean: global mean value of incoming data for potential padding + ignore_index: ignore index for potential padding + """ + super().__init__() + + self.size = tuple(_setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")) + self.pad_if_needed = pad_if_needed + self.pad_value = meta['data_mean'] + self.ignore_index = meta['ignore_index'] + + def get_params(self, data: dict) -> Tuple[int, int, int, int]: + """Get parameters for ``crop`` for a random crop. + Args: - dataset (GeoFMDataset): dataset used. - encoder (Encoder): encoder used. - data_min (torch.Tensor): data_min scalar tensor. - data_max (torch.Tensor): data_max scalar tensor. + data (dict): input data. + + Returns: + tuple: params (i, j, h, w) to be passed to ``crop`` for random crop. 
""" - super().__init__(dataset, encoder) - self.normalizers = {} - self.data_min_tensors = {} - self.data_max_tensors = {} - self.min = data_min - self.max = data_max - for modality in self.encoder.input_bands: - self.data_min_tensors[modality] = torch.tensor( - self.dataset.data_min[modality] - ).reshape((-1, 1, 1, 1)) - self.data_max_tensors[modality] = torch.tensor( - self.dataset.data_max[modality] - ).reshape((-1, 1, 1, 1)) - - def __getitem__( - self, index: int + h, w = data["image"][list(data["image"].keys())[0]].shape[-2:] + th, tw = self.size + if h < th or w < tw: + raise ValueError(f"Required crop size {(th, tw)} is larger than input image size {(h, w)}") + + if w == tw and h == th: + return 0, 0, h, w + + i = torch.randint(0, h - th + 1, size=(1,)).item() + j = torch.randint(0, w - tw + 1, size=(1,)).item() + return i, j, th, tw + + def check_pad(self, data: dict[str, torch.Tensor | dict[str, torch.Tensor]], + ) -> dict[str, torch.Tensor | dict[str, torch.Tensor]]: + _, t, height, width = data["image"][list(data["image"].keys())[0]].shape + + if height < self.size[0] or width < self.size[1]: + pad_img = max(self.size[0] - height, 0), max(self.size[1] - width, 0) + height, width = height + 2 * pad_img[0], width + 2 * pad_img[1] + for k, v in data["image"].items(): + padded_img = self.pad_value[k].reshape(-1, 1, 1, 1).repeat(1, t, height, width) + padded_img[:, :, pad_img[0]:-pad_img[0], pad_img[1]:-pad_img[1]] = v + data["image"][k] = padded_img + + data["target"] = TF.pad(data["target"], padding=padded_img, fill=self.ignore_index, padding_mode='constant') + + return data + + def __call__( + self, data: dict[str, torch.Tensor | dict[str, torch.Tensor]] ) -> dict[str, torch.Tensor | dict[str, torch.Tensor]]: - """Apply Min/Max Normalization to the data. + """Random crop the data. Args: - index (int): index of data. + data (dict): input data. Returns: dict[str, torch.Tensor | dict[str, torch.Tensor]]: output dictionary following the format {"image": @@ -655,149 +425,142 @@ def __getitem__( "target": torch.Tensor of shape (H W), "metadata": dict}. """ + self.check_size(data) + + if self.pad_if_needed: + data = self.check_pad(data) + + i, j, h, w = self.get_params(data=data) + + for k, v in data["image"].items(): + data["image"][k] = TF.crop(v, i, j, h, w) + + data["target"] = TF.crop(data["target"], i, j, h, w) - data = self.dataset[index] - for modality in self.encoder.input_bands: - data["image"][modality] = ( - (data["image"][modality] - self.data_min_tensors[modality]) - * (self.max - self.min) - - self.min - ) / self.data_max_tensors[modality] - return data - + def update_meta(self, meta): + """Tracking the meta statistics/info for next processor.""" + meta['data_img_size'] = self.size[0] + return meta -class ColorAugmentation(BaseAugment): + +class RandomCropToEncoder(RandomCrop): def __init__( self, - dataset: GeoFMDataset, - encoder: Encoder, - brightness: float = 0, - contrast: float = 0, - clip: bool = False, - br_probability: float = 0, - ct_probability: float = 0, + pad_if_needed: bool = False, + **meta ) -> None: - """Initialize the ColorAugmentation. + """Initialize the RandomCropToEncoder preprocessor. Args: - dataset (GeoFMDataset): dataset used. - encoder (Encoder): encoder used. - brightness (float, optional): brightness parameter. Defaults to 0. - contrast (float, optional): contrast. Defaults to 0. - clip (bool, optional): clip parameter. Defaults to False. - br_probability (float, optional): brightness augmentation probability. - Defaults to 0. 
- ct_probability (float, optional): contrast augmentation probability. - Defaults to 0. + size (int): crop size. + pad_if_needed (bool, optional): whether to pad. Defaults to False. + meta: statistics/info of the input data and target encoder + data_mean: global mean value of incoming data for potential padding + ignore_index: ignore index for potential padding """ - super().__init__(dataset, encoder) - self.brightness = brightness - self.contrast = contrast - self.clip = clip - self.br_probability = br_probability - self.ct_probability = ct_probability - - def adjust_brightness( - self, image: torch.Tensor, factor: float | torch.Tensor, clip_output: bool - ) -> torch.Tensor: - """Adjust the brightness of the image. + size = meta['encoder_input_size'] + super().__init__( + size, pad_if_needed, **meta + ) + +class FocusRandomCrop(RandomCrop): + def __init__( + self, + size: int, + pad_if_needed: bool = False, + **meta + ) -> None: + """Initialize the FocusRandomCrop preprocessor. Args: - image (torch.Tensor): input image of shape (C T H W) (T=1 if single timeframe). - factor (float | torch.Tensor): adjustment factor. - clip_output (bool): whether to clip the output. - Returns: - torch.Tensor: output image of shape (C T H W) (T=1 if single timeframe). + size (int): crop size. + pad_if_needed (bool, optional): whether to pad. Defaults to False. + meta: statistics/info of the input data and target encoder + data_mean: global mean value of incoming data for potential padding + ignore_index: ignore index for potential padding """ - if isinstance(factor, float): - factor = torch.as_tensor(factor, device=image.device, dtype=image.dtype) - while len(factor.shape) != len(image.shape): - factor = factor[..., None] - - img_adjust = image + factor - if clip_output: - img_adjust = img_adjust.clamp(min=-1.0, max=1.0) + super().__init__( + size, + pad_if_needed, + **meta) - return img_adjust + def get_params(self, data: dict) -> Tuple[int, int, int, int]: + """Get parameters for ``crop`` for a random crop. - def adjust_contrast( - self, image: torch.Tensor, factor: torch.Tensor | float, clip_output: bool - ) -> torch.Tensor: - """Adjust the contrast of the image. Args: - image (torch.Tensor): image input of shape (C T H W) (T=1 if single timeframe). - factor (torch.Tensor | float): augmentation factor. - clip_output (bool): whether to clip the output. + data (dict): input data. Returns: - torch.Tensor: output image of shape (C T H W) (T=1 if single timeframe). + tuple: params (i, j, h, w) to be passed to ``crop`` for random crop. """ - if isinstance(factor, float): - factor = torch.as_tensor(factor, device=image.device, dtype=image.dtype) - while len(factor.shape) != len(image.shape): - factor = factor[..., None] - assert factor >= 0, "Contrast factor must be positive" - img_adjust = image * factor - if clip_output: - img_adjust = img_adjust.clamp(min=-1.0, max=1.0) + h, w = data["target"].shape + th, tw = self.size - return img_adjust + if h < th or w < tw: + raise ValueError(f"Required crop size {(th, tw)} is larger than input image size {(h, w)}") - def __getitem__( - self, index: int - ) -> dict[str, torch.Tensor | dict[str, torch.Tensor]]: - """Apply ColorAugmentation to the data. 
+ if w == tw and h == th: + return 0, 0, h, w + + valid_map = data["target"] != self.ignore_index + idx = torch.arange(0, h*w)[valid_map.flatten()] + sample = idx[random.randint(0, idx.shape[0] - 1)] + y, x = sample // w, sample % w + + i = random.randint(max(0, y - th), min(y, h - th + 1)) + j = random.randint(max(0, x - tw), min(x, w - tw + 1)) + + return i, j, th, tw + + +class FocusRandomCropToEncoder(FocusRandomCrop): + def __init__( + self, + pad_if_needed: bool = False, + **meta + ) -> None: + """Initialize the FocusRandomCropToEncoder preprocessor. Args: - index (int): index of data. - Returns: - dict[str, torch.Tensor | dict[str, torch.Tensor]]: output dictionary following the format - {"image": - { - encoder_modality_1: torch.Tensor of shape (C T H W) (T=1 if single timeframe), - ... - encoder_modality_N: torch.Tensor of shape (C T H W) (T=1 if single timeframe), - }, - "target": torch.Tensor of shape (H W), - "metadata": dict}. + pad_if_needed (bool, optional): whether to pad. Defaults to False. + meta: statistics/info of the input data and target encoder + data_mean: global mean value of incoming data for potential padding + ignore_index: ignore index for potential padding """ - data = self.dataset[index] - for k, _ in data["image"].items(): - if k in self.encoder.input_bands: - brightness = random.uniform(-self.brightness, self.brightness) - if random.random() < self.br_probability: - data["image"][k] = self.adjust_brightness( - data["image"][k], brightness, self.clip - ) - - for k, _ in data["image"].items(): - if k in self.encoder.input_bands: - if random.random() < self.ct_probability: - contrast = random.uniform(1 - self.contrast, 1 + self.contrast) - data["image"][k] = self.adjust_contrast( - data["image"][k], contrast, self.clip - ) - - return data + size = meta['encoder_input_size'] + super().__init__( + size, pad_if_needed, **meta + ) -class Resize(BaseAugment): - def __init__(self, dataset: GeoFMDataset, encoder: Encoder, size: int) -> None: - """Initialize the Resize augmentation. +class ImportanceRandomCrop(RandomCrop): + def __init__( + self, + size, + pad_if_needed: bool = False, + num_trials: int = 10, + **meta + ) -> None: + """Initialize the FocusRandomCrop preprocessor. Args: - dataset (GeoFMDataset): dataset used. - encoder (Encoder): encoder used. - size (int): size of the output image. + size (int): crop size. + pad_if_needed (bool, optional): whether to pad. Defaults to False. + num_trials (int, optional): number of trials. Defaults to 10. + meta: statistics/info of the input data and target encoder + data_mean: global mean value of incoming data for potential padding + ignore_index: ignore index for potential padding """ - super().__init__(dataset, encoder) + super().__init__(size, pad_if_needed, **meta) + + self.num_trials = num_trials + self.class_weight = 1 / meta["class_distribution"] - self.size = (size, size) - def __getitem__( - self, index: int + def __call__( + self, data: dict[str, torch.Tensor | dict[str, torch.Tensor]] ) -> dict[str, torch.Tensor | dict[str, torch.Tensor]]: - """Resize the data. + """Random crop the data. Args: - index (int): index of data. + data (dict): input data. Returns: dict[str, torch.Tensor | dict[str, torch.Tensor]]: output dictionary following the format {"image": @@ -809,70 +572,93 @@ def __getitem__( "target": torch.Tensor of shape (H W), "metadata": dict}. 
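+
+        Note: ``num_trials`` candidate windows are drawn and one of them is
+        sampled with probability proportional to the summed inverse-class-frequency
+        weight of the pixels it covers, so crops containing rare classes are favoured.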
""" - data = self.dataset[index] - for k, v in data["image"].items(): - if k in self.encoder.input_bands: - data["image"][k] = T.Resize(self.size, interpolation=T.InterpolationMode.BILINEAR, antialias=True)(v) - - if data["target"].ndim == 2: - data["target"] = data["target"].unsqueeze(0) - data["target"] = T.Resize( - self.size, interpolation=T.InterpolationMode.NEAREST - )(data["target"]) - data["target"] = data["target"].squeeze(0) + self.check_size(data) + + if self.pad_if_needed: + data, height, width = self.check_pad(data) else: - data["target"] = T.Resize( - self.size, interpolation=T.InterpolationMode.NEAREST - )(data["target"]) + _, _, height, width = data["image"][list(data["image"].keys())[0]].shape - return data + valid = data["target"] != self.ignore_index + weight_map = torch.full(size=data["target"].shape, fill_value=1e-6, dtype=torch.float) + weight_map[valid] = self.class_weight[data["target"][valid]] + crop_candidates = [self.get_params(data) for _ in range(self.num_trials)] + crop_weights = [weight_map[i:i+h, j:j+w].sum().item()/(h*w) for i, j, h, w in crop_candidates] + crop_weights = np.array(crop_weights) + crop_weights = crop_weights / crop_weights.sum() -class ResizeToEncoder(Resize): - def __init__(self, dataset: GeoFMDataset, encoder: Encoder) -> None: - """Initialize the ResizeToEncoder augmentation. - Resize input data to the encoder input size. - Args: - dataset (GeoFMDataset): dataset used. - encoder (Encoder): encoder used. - """ - super().__init__(dataset, encoder, encoder.input_size) + crop_idx = np.random.choice(self.num_trials, p=crop_weights) + i, j, h, w = crop_candidates[crop_idx] + + for k, v in data["image"].items(): + data["image"][k] = TF.crop(v, i, j, h, w) + data["target"] = TF.crop(data["target"], i, j, h, w) -class RandomCrop(BaseAugment): + return data + + +class ImportanceRandomCropToEncoder(ImportanceRandomCrop): def __init__( self, - dataset: GeoFMDataset, - encoder: Encoder, - size: int, - padding: str | None = None, pad_if_needed: bool = False, - fill: int = 0, - padding_mode: str = "constant", + num_trials: int = 10, + **meta ) -> None: - """Initialize the RandomCrop augmentation. + """Initialize the FocusRandomCrop preprocessor. Args: - dataset (GeoFMDataset): dataset used. - encoder (Encoder): encoder used. size (int): crop size. - padding (str | None, optional): image padding. Defaults to None. pad_if_needed (bool, optional): whether to pad. Defaults to False. - fill (int, optional): value for padding. Defaults to 0. - padding_mode (str, optional): padding mode. Defaults to "constant". + num_trials (int, optional): number of trials. Defaults to 10. + meta: statistics/info of the input data and target encoder + data_mean: global mean value of incoming data for potential padding + ignore_index: ignore index for potential padding """ - super().__init__(dataset, encoder) - self.size = size - self.padding = padding - self.pad_if_needed = pad_if_needed - self.fill = fill - self.padding_mode = padding_mode + size = meta["encoder_input_size"] + super().__init__(size, pad_if_needed, num_trials, **meta) + + +class Resize(BasePreprocessor): + def __init__( + self, + size: int | Sequence[int], + interpolation=T.InterpolationMode.BILINEAR, + antialias: Optional[bool] = True, + resize_target: bool = True, + **meta + ) -> None: + """Initialize the Resize preprocessor. + Args: + size (sequence or int): Desired output size. If size is a sequence like + (h, w), output size will be matched to this. 
+ interpolation (InterpolationMode): Desired interpolation enum defined by + :class:`torchvision.transforms.InterpolationMode`. + antialias (bool, optional): Whether to apply antialiasing. + resize_target (bool, optional): Whether to resize the target + meta: statistics/info of the input data and target encoder + """ + super().__init__() + + if not isinstance(size, (int, Sequence)): + raise TypeError(f"Size should be int or sequence. Got {type(size)}") + if isinstance(size, Sequence) and len(size) not in (1, 2): + raise ValueError("If size is a sequence, it should have 1 or 2 values") + self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.") - def __getitem__( - self, index: int + if isinstance(interpolation, int): + interpolation = TF._interpolation_modes_from_int(interpolation) + + self.interpolation = interpolation + self.antialias = antialias + self.resize_target = resize_target + + def __call__( + self, data: dict[str, torch.Tensor | dict[str, torch.Tensor]] ) -> dict[str, torch.Tensor | dict[str, torch.Tensor]]: - """Random crop the data. + """Resize the data. Args: - index (int): index of data. + data (dict): input data. Returns: dict[str, torch.Tensor | dict[str, torch.Tensor]]: output dictionary following the format {"image": @@ -884,145 +670,186 @@ def __getitem__( "target": torch.Tensor of shape (H W), "metadata": dict}. """ - data = self.dataset[index] - # Use the first image to determine parameters - i, j, h, w = T.RandomCrop.get_params( - data["image"][list(data["image"].keys())[0]], - output_size=(self.size, self.size), - ) for k, v in data["image"].items(): - if k in self.encoder.input_bands: - data["image"][k] = T.functional.crop(v, i, j, h, w) - data["target"] = T.functional.crop(data["target"], i, j, h, w) + data["image"][k] = TF.resize(data["image"][k], self.size, interpolation=self.interpolation, antialias=self.antialias) + + if self.resize_target: + if torch.is_floating_point(data["target"]): + data["target"] = TF.resize(data["target"].unsqueeze(0), size=self.size, interpolation=T.InterpolationMode.BILINEAR).squeeze(0) + else: + data["target"] = TF.resize(data["target"].unsqueeze(0), size=self.size, interpolation=T.InterpolationMode.NEAREST).squeeze(0) return data + def update_meta(self, meta): + """Tracking the meta statistics/info for next processor.""" + meta['data_img_size'] = self.size[0] + return meta -class RandomCropToEncoder(RandomCrop): - def __init__( - self, - dataset: GeoFMDataset, - encoder: Encoder, - padding: str | None = None, - pad_if_needed: bool = False, - fill: int = 0, - padding_mode: str = "constant", - ) -> None: - """Initialize the RandomCropToEncoder augmentation. - Apply RandomCrop to the encoder input size. - Args: - dataset (GeoFMDataset): dataset used. - encoder (Encoder): encoder used. - padding (str | None, optional): image padding. Defaults to None. - pad_if_needed (bool, optional): whether to pad or not. Defaults to False. - fill (int, optional): value for padding. Defaults to 0. - padding_mode (str, optional): padding mode. Defaults to "constant". - """ - size = encoder.input_size - super().__init__( - dataset, encoder, size, padding, pad_if_needed, fill, padding_mode - ) +class ResizeToEncoder(Resize): + def __init__(self, + interpolation=T.InterpolationMode.BILINEAR, + antialias: Optional[bool] = True, + resize_target: bool = True, + **meta + ) -> None: + """Initialize the ResizeToEncoder preprocessor. 
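+        The output size is taken from ``meta['encoder_input_size']`` rather than
+        from an explicit ``size`` argument.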
+ Args: + interpolation (InterpolationMode): Desired interpolation enum defined by + :class:`torchvision.transforms.InterpolationMode`. + antialias (bool, optional): Whether to apply antialiasing. + resize_target (bool, optional): Whether to resize the target + meta: statistics/info of the input data and target encoder + """ + size = meta['encoder_input_size'] + super().__init__(size, interpolation, antialias, resize_target, **meta) + + +class RandomResizedCrop(BasePreprocessor): + def __init__(self, + size: int | Sequence[int], + scale: Tuple[float, float] = (0.08, 1.0), + ratio: Tuple[float, float] = (0.75, 1.3333333333333333), + interpolation=T.InterpolationMode.BILINEAR, + antialias: Optional[bool] = True, + resize_target: bool = True, + **meta + )-> None: + """Initialize the RandomResizedCrop preprocessor. + Args: + size (int): crop size. + scale (list): range of scale of the origin size cropped + ratio (list): range of aspect ratio of the origin aspect ratio cropped + interpolation (InterpolationMode): Desired interpolation enum defined by + :class:`torchvision.transforms.InterpolationMode`. + antialias (bool, optional): Whether to apply antialiasing. + resize_target (bool, optional): Whether to resize the target + meta: statistics/info of the input data and target encoder + """ + super().__init__() + self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.") + + if not isinstance(scale, Sequence): + raise TypeError("Scale should be a sequence") + if not isinstance(ratio, Sequence): + raise TypeError("Ratio should be a sequence") + if (scale[0] > scale[1]) or (ratio[0] > ratio[1]): + raise ValueError("Scale and ratio should be of kind (min, max)") + + if isinstance(interpolation, int): + interpolation = TF._interpolation_modes_from_int(interpolation) + + self.interpolation = interpolation + self.antialias = antialias + self.scale = scale + self.ratio = ratio + self.resize_target = resize_target + + @staticmethod + def get_params(input_size: Tuple[int, int], scale: Tuple[float, float], ratio: Tuple[float, float]) -> Tuple[int, int, int, int]: + """Get parameters for ``crop`` for a random sized crop. + + Args: + img (PIL Image or Tensor): Input image. + scale (list): range of scale of the origin size cropped + ratio (list): range of aspect ratio of the origin aspect ratio cropped -class ImportanceRandomCrop(BaseAugment): - def __init__( - self, - dataset: GeoFMDataset, - encoder: Encoder, - size: int, - padding: str | None = None, - pad_if_needed: bool = False, - fill: int = 0, - padding_mode: str = "constant", - n_crops: int = 10, - ) -> None: - """Initialize the ImportanceRandomCrop. + Returns: + tuple: params (i, j, h, w) to be passed to ``crop`` for a random + sized crop. 
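+
+        Example (illustrative):
+            >>> i, j, h, w = RandomResizedCrop.get_params((512, 512), (0.5, 1.0), (1.0, 1.0))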
+ """ + height, width = input_size + area = height * width + + log_ratio = torch.log(torch.tensor(ratio)) + for _ in range(10): + target_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item() + aspect_ratio = torch.exp(torch.empty(1).uniform_(log_ratio[0], log_ratio[1])).item() + + w = int(round(math.sqrt(target_area * aspect_ratio))) + h = int(round(math.sqrt(target_area / aspect_ratio))) + + if 0 < w <= width and 0 < h <= height: + i = torch.randint(0, height - h + 1, size=(1,)).item() + j = torch.randint(0, width - w + 1, size=(1,)).item() + return i, j, h, w + + # Fallback to central crop + in_ratio = float(width) / float(height) + if in_ratio < min(ratio): + w = width + h = int(round(w / min(ratio))) + elif in_ratio > max(ratio): + h = height + w = int(round(h * max(ratio))) + else: # whole image + w = width + h = height + i = (height - h) // 2 + j = (width - w) // 2 + return i, j, h, w + + def __call__( + self, data: dict[str, torch.Tensor | dict[str, torch.Tensor]] + ) -> dict[str, torch.Tensor | dict[str, torch.Tensor]]: - Args: - dataset (GeoFMDataset): dataset used. - encoder (Encoder): encoder used. - size (int): crop size. - padding (str | None, optional): image padding. Defaults to None. - pad_if_needed (bool, optional): whether to pad. Defaults to False. - fill (int, optional): value for padding. Defaults to 0. - padding_mode (str, optional): padding mode. Defaults to "constant". - n_crops (int, optional): number of crops. Defaults to 10. - """ - super().__init__(dataset, encoder) - self.size = size - self.padding = padding - self.pad_if_needed = pad_if_needed - self.fill = fill - self.padding_mode = padding_mode - self.n_crops = n_crops - - def __getitem__(self, index): - data = self.dataset[index] - - # dataset needs to provide a weighting layer - weight = torch.zeros_like(data["target"]).float() - for i, freq in enumerate(self.distribution): - weight[data["target"] == i] = 1 - freq - weight[data["target"] == self.ignore_index] = 1e-6 - # candidates for random crop - crop_candidates, crop_weights = [], [] - for _ in range(self.n_crops): - i, j, h, w = T.RandomCrop.get_params( - data["image"][ - list(data["image"].keys())[0] - ], # Use the first image to determine parameters - output_size=(self.size, self.size), - ) - crop_candidates.append((i, j, h, w)) + self.check_size(data) - crop_weight = T.functional.crop(weight, i, j, h, w) - crop_weights.append(torch.sum(crop_weight).item()) + _, t, h_img, w_img = data["image"][list(data["image"].keys())[0]].shape - crop_weights = np.array(crop_weights) - crop_weights = crop_weights / np.sum(crop_weights) - crop_idx = np.random.choice(self.n_crops, p=crop_weights) - i, j, h, w = crop_candidates[crop_idx] + i, j, h, w = self.get_params((h_img, w_img), self.scale, self.ratio) for k, v in data["image"].items(): - if k in self.encoder.input_bands: - data["image"][k] = T.functional.crop(v, i, j, h, w) - data["target"] = T.functional.crop(data["target"], i, j, h, w) + data["image"][k] = TF.resized_crop(data["image"][k], i, j, h, w, self.size, self.interpolation, antialias=self.antialias) - return data + if self.resize_target: + if torch.is_floating_point(data["target"]): + data["target"] = TF.resized_crop(data["target"].unsqueeze(0), i, j, h, w, self.size, T.InterpolationMode.BILINEAR).squeeze(0) + else: + data["target"] = TF.resized_crop(data["target"].unsqueeze(0), i, j, h, w, self.size, T.InterpolationMode.NEAREST).squeeze(0) + else: + data["target"] = TF.crop(data["target"], i, j, h, w) -class 
ImportanceRandomCropToEncoder(ImportanceRandomCrop): - def __init__( - self, - dataset: GeoFMDataset, - encoder: Encoder, - padding: str | None = None, - pad_if_needed: bool = False, - fill: int = 0, - padding_mode: str = "constant", - n_crops: int = 10, - ) -> None: - """Initialize the ImportanceRandomCropToEncoder. - Initialize the ImportanceRandomCrop augmentation to the encoder input size. + return data - Args: - dataset (GeoFMDataset): dataset used. - encoder (Encoder): encoder used. - size (int): crop size. - padding (str | None, optional): image padding. Defaults to None. - pad_if_needed (bool, optional): whether to pad. Defaults to False. - fill (int, optional): value for padding. Defaults to 0. - padding_mode (str, optional): padding mode. Defaults to "constant". - n_crops (int, optional): number of crops. Defaults to 10. - """ - size = encoder.input_size - super().__init__( - dataset=dataset, - encoder=encoder, - size=size, - padding=padding, - pad_if_needed=pad_if_needed, - fill=fill, - padding_mode=padding_mode, - n_crops=n_crops, - ) + def update_meta(self, meta): + """Tracking the meta statistics/info for next processor.""" + meta['data_img_size'] = self.size[0] + return meta + +class RandomResizedCropToEncoder(RandomResizedCrop): + def __init__(self, + scale: Tuple[float, float] = (0.08, 1.0), + ratio: Tuple[float, float] = (0.75, 1.3333333333333333), + interpolation=T.InterpolationMode.BILINEAR, + antialias: Optional[bool] = True, + resize_target: bool = True, + **meta + ) -> None: + """Initialize the RandomResizedCropToEncoder preprocessor. + Args: + scale (list): range of scale of the origin size cropped + ratio (list): range of aspect ratio of the origin aspect ratio cropped + interpolation (InterpolationMode): Desired interpolation enum defined by + :class:`torchvision.transforms.InterpolationMode`. + antialias (bool, optional): Whether to apply antialiasing. 
+ resize_target (bool, optional): Whether to resize the target + meta: statistics/info of the input data and target encoder + """ + size = meta['encoder_input_size'] + super().__init__(size, scale, ratio, interpolation, antialias, resize_target, **meta) + + +def _setup_size(size, error_msg): + if isinstance(size, numbers.Number): + return int(size), int(size) + + if isinstance(size, Sequence) and len(size) == 1: + return size[0], size[0] + + if len(size) != 2: + raise ValueError(error_msg) + + return size \ No newline at end of file diff --git a/pangaea/engine/evaluator.py b/pangaea/engine/evaluator.py index bfe6824c..5e6ec7ea 100644 --- a/pangaea/engine/evaluator.py +++ b/pangaea/engine/evaluator.py @@ -2,6 +2,8 @@ import os import time from pathlib import Path +import math +import wandb import torch import torch.nn.functional as F @@ -38,33 +40,34 @@ class Evaluator: """ def __init__( - self, - val_loader: DataLoader, - exp_dir: str | Path, - device: torch.device, - use_wandb: bool, + self, + val_loader: DataLoader, + exp_dir: str | Path, + device: torch.device, + inference_mode: str = 'sliding', + sliding_inference_batch: int = None, + use_wandb: bool = False, ) -> None: + self.rank = int(os.environ["RANK"]) self.val_loader = val_loader self.logger = logging.getLogger() self.exp_dir = exp_dir self.device = device + self.inference_mode = inference_mode + self.sliding_inference_batch = sliding_inference_batch self.classes = self.val_loader.dataset.classes self.split = self.val_loader.dataset.split self.ignore_index = self.val_loader.dataset.ignore_index self.num_classes = len(self.classes) self.max_name_len = max([len(name) for name in self.classes]) - self.use_wandb = use_wandb - - if use_wandb: - import wandb - self.wandb = wandb + self.use_wandb = use_wandb def evaluate( - self, - model: torch.nn.Module, - model_name: str, - model_ckpt_path: str | Path | None = None, + self, + model: torch.nn.Module, + model_name: str, + model_ckpt_path: str | Path | None = None, ) -> None: raise NotImplementedError @@ -77,6 +80,53 @@ def compute_metrics(self): def log_metrics(self, metrics): pass + @staticmethod + def sliding_inference(model, img, input_size, output_shape=None, stride=None, max_batch=None): + b, c, t, height, width = img[list(img.keys())[0]].shape + + if stride is None: + h = int(math.ceil(height / input_size)) + w = int(math.ceil(width / input_size)) + else: + h = math.ceil((height - input_size) / stride) + 1 + w = math.ceil((width - input_size) / stride) + 1 + + h_grid = torch.linspace(0, height - input_size, h).round().long() + w_grid = torch.linspace(0, width - input_size, w).round().long() + num_crops_per_img = h * w + + for k, v in img.items(): + img_crops = [] + for i in range(h): + for j in range(w): + img_crops.append(v[:, :, :, h_grid[i]:h_grid[i] + input_size, w_grid[j]:w_grid[j] + input_size]) + img[k] = torch.cat(img_crops, dim=0) + + pred = [] + max_batch = max_batch if max_batch is not None else b * num_crops_per_img + batch_num = int(math.ceil(b * num_crops_per_img / max_batch)) + for i in range(batch_num): + img_ = {k: v[max_batch * i: min(max_batch * i + max_batch, b * num_crops_per_img)] for k, v in img.items()} + pred_ = model.forward(img_, output_shape=(input_size, input_size)) + pred.append(pred_) + pred = torch.cat(pred, dim=0) + pred = pred.view(num_crops_per_img, b, -1, input_size, input_size).transpose(0, 1) + + merged_pred = torch.zeros((b, pred.shape[2], height, width), device=pred.device) + pred_count = torch.zeros((b, height, width), dtype=torch.long, 
device=pred.device) + for i in range(h): + for j in range(w): + merged_pred[:, :, h_grid[i]:h_grid[i] + input_size, + w_grid[j]:w_grid[j] + input_size] += pred[:, h * i + j] + pred_count[:, h_grid[i]:h_grid[i] + input_size, + w_grid[j]:w_grid[j] + input_size] += 1 + + merged_pred = merged_pred / pred_count.unsqueeze(1) + if output_shape is not None: + merged_pred = F.interpolate(merged_pred, size=output_shape, mode="bilinear") + + return merged_pred + class SegEvaluator(Evaluator): """ @@ -100,13 +150,15 @@ class SegEvaluator(Evaluator): """ def __init__( - self, - val_loader: DataLoader, - exp_dir: str | Path, - device: torch.device, - use_wandb: bool, + self, + val_loader: DataLoader, + exp_dir: str | Path, + device: torch.device, + inference_mode: str = 'sliding', + sliding_inference_batch: int = None, + use_wandb: bool = False, ): - super().__init__(val_loader, exp_dir, device, use_wandb) + super().__init__(val_loader, exp_dir, device, inference_mode, sliding_inference_batch, use_wandb) @torch.no_grad() def evaluate(self, model, model_name='model', model_ckpt_path=None): @@ -129,11 +181,19 @@ def evaluate(self, model, model_name='model', model_ckpt_path=None): ) for batch_idx, data in enumerate(tqdm(self.val_loader, desc=tag)): + image, target = data["image"], data["target"] image = {k: v.to(self.device) for k, v in image.items()} target = target.to(self.device) - logits = model(image, output_shape=target.shape[-2:]) + if self.inference_mode == "sliding": + input_size = model.module.encoder.input_size + logits = self.sliding_inference(model, image, input_size, output_shape=target.shape[-2:], + max_batch=self.sliding_inference_batch) + elif self.inference_mode == "whole": + logits = model(image, output_shape=target.shape[-2:]) + else: + raise NotImplementedError((f"Inference mode {self.inference_mode} is not implemented.")) if logits.shape[1] == 1: pred = (torch.sigmoid(logits) > 0.5).type(torch.int64).squeeze(dim=1) else: @@ -141,14 +201,14 @@ def evaluate(self, model, model_name='model', model_ckpt_path=None): valid_mask = target != self.ignore_index pred, target = pred[valid_mask], target[valid_mask] count = torch.bincount( - (pred * self.num_classes + target), minlength=self.num_classes**2 + (pred * self.num_classes + target), minlength=self.num_classes ** 2 ) confusion_matrix += count.view(self.num_classes, self.num_classes) torch.distributed.all_reduce( confusion_matrix, op=torch.distributed.ReduceOp.SUM ) - metrics = self.compute_metrics(confusion_matrix) + metrics = self.compute_metrics(confusion_matrix.cpu()) self.log_metrics(metrics) used_time = time.time() - t @@ -200,16 +260,16 @@ def log_metrics(self, metrics): def format_metric(name, values, mean_value): header = f"------- {name} --------\n" metric_str = ( - "\n".join( - c.ljust(self.max_name_len, " ") + "\t{:>7}".format("%.3f" % num) - for c, num in zip(self.classes, values) - ) - + "\n" + "\n".join( + c.ljust(self.max_name_len, " ") + "\t{:>7}".format("%.3f" % num) + for c, num in zip(self.classes, values) + ) + + "\n" ) mean_str = ( - "-------------------\n" - + "Mean".ljust(self.max_name_len, " ") - + "\t{:>7}".format("%.3f" % mean_value) + "-------------------\n" + + "Mean".ljust(self.max_name_len, " ") + + "\t{:>7}".format("%.3f" % mean_value) ) return header + metric_str + mean_str @@ -231,7 +291,7 @@ def format_metric(name, values, mean_value): self.logger.info(macc_str) if self.use_wandb: - self.wandb.log( + wandb.log( { f"{self.split}_mIoU": metrics["mIoU"], f"{self.split}_mF1": metrics["mF1"], @@ -274,18 +334,20 
@@ class RegEvaluator(Evaluator): """ def __init__( - self, - val_loader: DataLoader, - exp_dir: str | Path, - device: torch.device, - use_wandb: bool, + self, + val_loader: DataLoader, + exp_dir: str | Path, + device: torch.device, + inference_mode: str = 'sliding', + sliding_inference_batch: int = None, + use_wandb: bool = False, ): - super().__init__(val_loader, exp_dir, device, use_wandb) + super().__init__(val_loader, exp_dir, device, inference_mode, sliding_inference_batch, use_wandb) @torch.no_grad() def evaluate(self, model, model_name='model', model_ckpt_path=None): t = time.time() - + if model_ckpt_path is not None: model_dict = torch.load(model_ckpt_path, map_location=self.device) model_name = os.path.basename(model_ckpt_path).split('.')[0] @@ -299,17 +361,29 @@ def evaluate(self, model, model_name='model', model_ckpt_path=None): model.eval() tag = f'Evaluating {model_name} on {self.split} set' - mse = 0.0 + + mse = torch.zeros(1, device=self.device) + for batch_idx, data in enumerate(tqdm(self.val_loader, desc=tag)): image, target = data['image'], data['target'] image = {k: v.to(self.device) for k, v in image.items()} target = target.to(self.device) - logits = model(image, output_shape=target.shape[-2:]).squeeze(dim=1) - loss = F.mse_loss(logits, target) - mse +=loss - mse_avg = mse / len(self.val_loader) - metrics = {"MSE" : mse_avg.item(), "RMSE" : torch.sqrt(mse_avg).item()} + if self.inference_mode == "sliding": + input_size = model.module.encoder.input_size + logits = self.sliding_inference(model, image, input_size, output_shape=target.shape[-2:], + max_batch=self.sliding_inference_batch) + elif self.inference_mode == "whole": + logits = model(image, output_shape=target.shape[-2:]).squeeze(dim=1) + else: + raise NotImplementedError((f"Inference mode {self.inference_mode} is not implemented.")) + + mse += F.mse_loss(logits, target, reduction='sum') + + torch.distributed.all_reduce(mse, op=torch.distributed.ReduceOp.SUM) + mse = mse / len(self.val_loader.dataset) + + metrics = {"MSE": mse.item(), "RMSE": torch.sqrt(mse).item()} self.log_metrics(metrics) used_time = time.time() - t @@ -322,9 +396,9 @@ def __call__(self, model, model_name='model', model_ckpt_path=None): def log_metrics(self, metrics): header = "------- MSE and RMSE --------\n" - mse = "-------------------\n" + 'MSE \t{:>7}'.format('%.3f' % metrics['MSE'])+'\n' + mse = "-------------------\n" + 'MSE \t{:>7}'.format('%.3f' % metrics['MSE']) + '\n' rmse = "-------------------\n" + 'RMSE \t{:>7}'.format('%.3f' % metrics['RMSE']) - self.logger.info(header+mse+rmse) + self.logger.info(header + mse + rmse) if self.use_wandb: - self.wandb.log({f"{self.split}_MSE": metrics["MSE"], f"{self.split}_RMSE": metrics["RMSE"]}) + wandb.log({f"{self.split}_MSE": metrics["MSE"], f"{self.split}_RMSE": metrics["RMSE"]}) diff --git a/pangaea/engine/trainer.py b/pangaea/engine/trainer.py index d10cec40..5aebcdae 100644 --- a/pangaea/engine/trainer.py +++ b/pangaea/engine/trainer.py @@ -141,6 +141,7 @@ def train_one_epoch(self, epoch: int) -> None: image, target = data["image"], data["target"] image = {modality: value.to(self.device) for modality, value in image.items()} target = target.to(self.device) + self.training_stats["data_time"].update(time.time() - end_time) with torch.autocast( @@ -151,17 +152,19 @@ def train_one_epoch(self, epoch: int) -> None: self.optimizer.zero_grad() - if not torch.isnan(loss): - self.scaler.scale(loss).backward() - self.scaler.step(self.optimizer) - self.scaler.update() - 
self.training_stats['loss'].update(loss.item()) - with torch.no_grad(): - self.compute_logging_metrics(logits, target) - if (batch_idx + 1) % self.log_interval == 0: - self.log(batch_idx + 1, epoch) - else: - self.logger.warning("Skip batch {} because of nan loss".format(batch_idx + 1)) + if not torch.isfinite(loss): + raise FloatingPointError( + f"Rank {self.rank} got infinite/NaN loss at batch {batch_idx} of epoch {epoch}!" + ) + + self.scaler.scale(loss).backward() + self.scaler.step(self.optimizer) + self.scaler.update() + self.training_stats['loss'].update(loss.item()) + with torch.no_grad(): + self.compute_logging_metrics(logits, target) + if (batch_idx + 1) % self.log_interval == 0: + self.log(batch_idx + 1, epoch) self.lr_scheduler.step() @@ -216,6 +219,7 @@ def save_model( checkpoint (dict[str, dict | int] | None, optional): already prepared checkpoint dict. Defaults to None. """ if self.rank != 0: + torch.distributed.barrier() return checkpoint = self.get_checkpoint(epoch) if checkpoint is None else checkpoint suffix = "_best" if is_best else "_final" if is_final else "" @@ -224,6 +228,8 @@ def save_model( self.logger.info( f"Epoch {epoch} | Training checkpoint saved at {checkpoint_path}" ) + torch.distributed.barrier() + return def load_model(self, resume_path: str | pathlib.Path) -> None: """Load model from the checkpoint. @@ -307,9 +313,8 @@ def log(self, batch_idx: int, epoch) -> None: left_batch_all = ( self.batch_per_epoch * (self.n_epochs - epoch - 1) + left_batch_this_epoch ) - left_eval_times = ( - self.n_epochs + 0.5 - ) // self.eval_interval - self.training_stats["eval_time"].count + left_eval_times = ((self.n_epochs - 0.5) // self.eval_interval + 2 + - self.training_stats["eval_time"].count) left_time_this_epoch = sec_to_hm( left_batch_this_epoch * self.training_stats["batch_time"].avg ) diff --git a/pangaea/run.py b/pangaea/run.py index a5e3687d..58a11078 100644 --- a/pangaea/run.py +++ b/pangaea/run.py @@ -1,7 +1,6 @@ import os as os import pathlib import pprint -import random import time import hydra @@ -10,26 +9,26 @@ from hydra.core.hydra_config import HydraConfig from hydra.utils import instantiate from omegaconf import DictConfig, OmegaConf -from torch.utils.data import DataLoader, Dataset +from torch.utils.data import DataLoader from torch.utils.data.distributed import DistributedSampler +from pangaea.datasets.base import GeoFMDataset, GeoFMSubset, RawGeoFMDataset from pangaea.decoders.base import Decoder from pangaea.encoders.base import Encoder from pangaea.engine.evaluator import Evaluator from pangaea.engine.trainer import Trainer from pangaea.utils.collate_fn import get_collate_fn from pangaea.utils.logger import init_logger +from pangaea.utils.subset_sampler import get_subset_indices from pangaea.utils.utils import ( fix_seed, get_best_model_ckpt_path, get_generator, seed_worker, ) -from pangaea.utils.subset_sampler import get_subset_indices -from pangaea.datasets.base import GeoFMSubset -def get_exp_info(hydra_config: HydraConf) -> str: +def get_exp_info(hydra_config: HydraConf) -> dict[str, str]: """Create a unique experiment name based on the choices made in the config. 
Args: @@ -121,12 +120,6 @@ def main(cfg: DictConfig) -> None: logger.info("The experiment is stored in %s\n" % exp_dir) logger.info(f"Device used: {device}") - # get datasets - train_dataset: Dataset = instantiate(cfg.dataset, split="train") - val_dataset: Dataset = instantiate(cfg.dataset, split="val") - test_dataset: Dataset = instantiate(cfg.dataset, split="test") - logger.info("Built {} dataset.".format(cfg.dataset.dataset_name)) - encoder: Encoder = instantiate(cfg.encoder) encoder.load_encoder_weights(logger) logger.info("Built {}.".format(encoder.model_name)) @@ -151,34 +144,54 @@ def main(cfg: DictConfig) -> None: # training if train_run: + # get preprocessor + train_preprocessor = instantiate( + cfg.preprocessing.train, + dataset_cfg=cfg.dataset, + encoder_cfg=cfg.encoder, + _recursive_=False, + ) + val_preprocessor = instantiate( + cfg.preprocessing.val, + dataset_cfg=cfg.dataset, + encoder_cfg=cfg.encoder, + _recursive_=False, + ) - for preprocess in cfg.preprocessing.train: - train_dataset: Dataset = instantiate( - preprocess, dataset=train_dataset, encoder=encoder - ) - for preprocess in cfg.preprocessing.test: - val_dataset: Dataset = instantiate( - preprocess, dataset=val_dataset, encoder=encoder - ) + # get datasets + raw_train_dataset: RawGeoFMDataset = instantiate(cfg.dataset, split="train") + raw_val_dataset: RawGeoFMDataset = instantiate(cfg.dataset, split="val") + train_dataset = GeoFMDataset(raw_train_dataset, train_preprocessor) + val_dataset = GeoFMDataset(raw_val_dataset, val_preprocessor) + + logger.info("Built {} dataset.".format(cfg.dataset.dataset_name)) if 0 < cfg.limited_label_train < 1: indices = get_subset_indices( - train_dataset, task=task_name, strategy=cfg.limited_label_strategy, - label_fraction=cfg.limited_label_train, num_bins=cfg.stratification_bins, logger=logger + train_dataset, + task=task_name, + strategy=cfg.limited_label_strategy, + label_fraction=cfg.limited_label_train, + num_bins=cfg.stratification_bins, + logger=logger, ) train_dataset = GeoFMSubset(train_dataset, indices) - + if 0 < cfg.limited_label_val < 1: indices = get_subset_indices( - val_dataset, task=task_name, strategy=cfg.limited_label_strategy, - label_fraction=cfg.limited_label_val, num_bins=cfg.stratification_bins, logger=logger + val_dataset, + task=task_name, + strategy=cfg.limited_label_strategy, + label_fraction=cfg.limited_label_val, + num_bins=cfg.stratification_bins, + logger=logger, ) val_dataset = GeoFMSubset(val_dataset, indices) - + logger.info( - f"Total number of train patches: {len(train_dataset)}\n" - f"Total number of validation patches: {len(val_dataset)}\n" - ) + f"Total number of train patches: {len(train_dataset)}\n" + f"Total number of validation patches: {len(val_dataset)}\n" + ) # get train val data loaders train_loader = DataLoader( @@ -198,8 +211,8 @@ def main(cfg: DictConfig) -> None: val_loader = DataLoader( val_dataset, sampler=DistributedSampler(val_dataset), - batch_size=cfg.batch_size, - num_workers=cfg.num_workers, + batch_size=cfg.test_batch_size, + num_workers=cfg.test_num_workers, pin_memory=True, persistent_workers=False, worker_init_fn=seed_worker, @@ -237,16 +250,22 @@ def main(cfg: DictConfig) -> None: trainer.train() # Evaluation - for preprocess in cfg.preprocessing.test: - test_dataset: Dataset = instantiate( - preprocess, dataset=test_dataset, encoder=encoder - ) + test_preprocessor = instantiate( + cfg.preprocessing.test, + dataset_cfg=cfg.dataset, + encoder_cfg=cfg.encoder, + _recursive_=False, + ) + + # get datasets + 
raw_test_dataset: RawGeoFMDataset = instantiate(cfg.dataset, split="test") + test_dataset = GeoFMDataset(raw_test_dataset, test_preprocessor) test_loader = DataLoader( test_dataset, sampler=DistributedSampler(test_dataset), - batch_size=cfg.batch_size, - num_workers=cfg.num_workers, + batch_size=cfg.test_batch_size, + num_workers=cfg.test_num_workers, pin_memory=True, persistent_workers=False, drop_last=False, diff --git a/pangaea/utils/data_types.py b/pangaea/utils/data_types.py deleted file mode 100644 index 77fa4cc9..00000000 --- a/pangaea/utils/data_types.py +++ /dev/null @@ -1,426 +0,0 @@ -from collections import OrderedDict -from typing import List, Dict, Union, Tuple -import torch - - -SENSORS = ["s2", "s1", "l8", "l7", "l5", "l4"] -SENSOR_BANDS = { - "s2": ["b1", "b2", "b3", "b4", "b5", "b6", "b7", "b8", "b8a", "b9", "b11", "b12"], - "s1": ["vv", "vh"], - "l8": ["b1", "b2", "b3", "b4", "b5", "b6", "b7", "b8", "b9", "b10", "b11", "b12"], - "l7": ["b1", "b2", "b3", "b4", "b5", "b7"], - "l5": ["b1", "b2", "b3", "b4", "b5", "b6", "b7"], - "l4": ["b1", "b2", "b3", "b4", "b5", "b6", "b7"], -} - - -class SensorData(OrderedDict): - """ - A class to represent sensor data for a specific sensor. This class extends OrderedDict - to store band data as key-value pairs where the key is the band name and the value is - a torch.Tensor representing the band data. - - Attributes: - sensor (str): The name of the sensor. - bands (List[str]): The list of bands for the sensor. - """ - - def __init__(self, sensor: str) -> None: - """ - Initializes the SensorData object with the given sensor name. - - Args: - sensor (str): The name of the sensor. - """ - super(SensorData, self).__init__() - self.sensor = sensor - self.bands = SENSOR_BANDS[sensor] - - def __setitem__(self, band_name: str, band_data: torch.Tensor) -> None: - """ - Sets the band data for a specific band name. If the band name is "all", it sets - the data for all bands at once. - - Args: - band_name (str): The name of the band. - band_data (torch.Tensor): The data for the band. - - Raises: - ValueError: If the band name is not a string or the band data is not a torch.Tensor. - ValueError: If the band name is "all" and the shape of the band data does not match - the expected number of bands. - ValueError: If the band name is not in the list of valid bands for the sensor. - """ - if not isinstance(band_name, str): - raise ValueError(f"Band name must be a string, got {type(band_name)}") - if not isinstance(band_data, torch.Tensor): - raise ValueError(f"Value must be a torch.Tensor, got {type(band_data)}") - - if band_name == "all": - # all bands are provided all at once - assert ( - len(band_data.shape) == 3 - ), f"Expected 3D tensor, got {band_data.shape}" - if band_data.shape[-3] != len(self.bands): - raise ValueError( - f"Expected {len(self.bands)} bands, got {band_data.shape[-3]}" - ) - for bn, bd in zip(self.bands, band_data): - self.__setitem__(bn, bd) - else: - assert ( - len(band_data.shape) == 2 - ), f"Expected 2D tensor, got {band_data.shape}" - if band_name not in self.bands: - raise ValueError(f"Invalid band name: {band_name}") - super(SensorData, self).__setitem__(band_name, band_data) - - def __getitem__(self, band_name: str) -> torch.Tensor: - """ - Gets the band data for a specific band name. If the band name is "all", it returns - the data for all bands stacked along the first dimension. - - Args: - band_name (str): The name of the band. - - Raises: - ValueError: If the band name is not a string. 
- - Returns: - torch.Tensor: The data for the specified band or all bands. - """ - if not isinstance(band_name, str): - raise ValueError(f"Band name must be a string, got {type(band_name)}") - - if band_name == "all": - band_data = [] - for bn in self.bands: - if bn not in self.keys(): - # if band not in self.keys(): - # we pad the missing band with zeros - band_data.append(torch.zeros_like(list(self.values())[0])) - else: - band_data.append(self[bn]) - return torch.stack(band_data, dim=-3) - - else: - if band_name not in self.keys(): - return None - return super(SensorData, self).__getitem__(band_name) - - def to(self, device): - """ - Moves all band data to the specified device. - - Args: - device: The device to move the data to. - - Returns: - SensorData: The SensorData object with data moved to the specified device. - """ - for key, value in self.items(): - self.__setitem__(key, value.to(device)) - return self - - def to_dtype(self, dtype): - """ - Converts all band data to the specified dtype. - - Args: - dtype: The dtype to convert the data to. - - Returns: - SensorData: The SensorData object with data converted to the specified dtype. - """ - for key, value in self.items(): - self.__setitem__(key, value.to(dtype)) - return self - - def to_device_dtype(self, device, dtype): - """ - Moves all band data to the specified device and converts it to the specified dtype. - - Args: - device: The device to move the data to. - dtype: The dtype to convert the data to. - - Returns: - SensorData: The SensorData object with data moved to the specified device and - converted to the specified dtype. - """ - for key, value in self.items(): - self.__setitem__(key, value.to(device, dtype)) - return self - -class TimeSerieData(OrderedDict): - def __init__(self, sensor: str) -> None: - super(TimeSerieData, self).__init__() - self.sensor = sensor - self.bands = SENSOR_BANDS[sensor] - - def __setitem__(self, band_name: str, timeserie: torch.Tensor) -> None: - # check if key is a string - if not isinstance(band_name, str): - raise ValueError(f"Band name must be a string, got {type(band_name)}") - if not isinstance(timeserie, torch.Tensor): - raise ValueError(f"Value must be a torch.Tensor, got {type(timeserie)}") - - assert len(timeserie.shape) in [ - 3, - 4, - ], f"Expected 3D or 4D tensor, got {timeserie.shape}" - timeserie_data = [] - for band_data in timeserie: - sd = SensorData(self.sensor) - sd[band_name] = band_data - timeserie_data.append(sd) - - super(TimeSerieData, self).__setitem__(band_name, timeserie_data) - - def __getitem__(self, band_name: str) -> torch.Tensor: - if not isinstance(band_name, str): - raise ValueError(f"Band name must be a string, got {type(band_name)}") - - if band_name == "all": - # get the first key - k = list(self.keys())[0] - sensor_data = super(TimeSerieData, self).__getitem__(k) - elif band_name not in self.keys(): - return None - else: - sensor_data = super(TimeSerieData, self).__getitem__(band_name) - sensor_data = [sd[band_name] for sd in sensor_data] - - return torch.stack(sensor_data, dim=0) - - def to(self, device): - for key, value in self.items(): - timeserie = [data.to(device) for data in value] - super(TimeSerieData, self).__setitem__(key, timeserie) - return self - - def to_dtype(self, dtype): - for key, value in self.items(): - self.__setitem__(key, value.to(dtype)) - return self - - def to_device_dtype(self, device, dtype): - for key, value in self.items(): - self.__setitem__(key, value.to(device, dtype)) - return self - - -class EoTensor(OrderedDict): - 
""" - A class to represent Earth Observation data. - """ - - def __init__(self, *args, **kwargs) -> None: - super(EoTensor, self).__init__(*args, **kwargs) - self.sensor_data = OrderedDict() - self.timeseries = OrderedDict() - - def __parse_key(self, key: str) -> Tuple[str, str | None, bool]: - # the format of the key is "sensor - [band] - [timeserie]" with the following rules: - # - sensor: one of the supported sensors - # - band: one of the bands of the sensor (optional) - # - timeserie: one of the timeseries of the sensor (optional) - # remove leading and trailing whitespaces - key_parts = [kp.strip() for kp in key.split("-")] - sensor = key_parts[0] - assert sensor in SENSORS, f"Invalid sensor name: {sensor}" - if len(key_parts) == 3 and key_parts[2] == "ts": - band = key_parts[1] - timeserie = True - elif len(key_parts) == 2: - # we need to check if the second part is a band or a timeserie - if key_parts[1] == "ts": - timeserie = True - band = "all" - else: - timeserie = False - band = key_parts[1] - assert ( - band in SENSOR_BANDS[sensor] - ), f"Invalid band name: {band} for sensor {sensor}" - else: - band = "all" - timeserie = False - - return sensor, band, timeserie - - def __setitem__(self, key: str, value: torch.Tensor) -> None: - # check if key is a string - if not isinstance(key, str): - raise ValueError(f"Key must be a string, got {type(key)}") - if not isinstance(value, torch.Tensor): - raise ValueError(f"Value must be a torch.Tensor, got {type(value)}") - - sensor, band, timeserie = self.__parse_key(key) - if timeserie: - self.__set_timeserie(sensor, band, value) - else: - self.__set_sensordata(sensor, band, value) - - def __set_sensordata(self, sensor: str, band: str, value: torch.Tensor) -> None: - sd = SensorData(sensor) - sd[band] = value - self.sensor_data[sensor] = sd - - def __set_timeserie(self, sensor: str, band: str, value: torch.Tensor) -> None: - ts = TimeSerieData(sensor) - ts[band] = value - self.timeseries[sensor] = ts - - def __getitem__(self, key: str) -> torch.Tensor: - # check if key is a string - if not isinstance(key, str): - raise ValueError(f"Key must be a string, got {type(key)}") - - sensor, band, timeserie = self.__parse_key(key) - if timeserie: - return self.__get_timeserie(sensor, band) - else: - return self.__get_sensordata(sensor, band) - - def __get_sensordata(self, sensor: str, band: str) -> torch.Tensor: - sd = self.sensor_data[sensor] - return sd[band] - - def __get_timeserie(self, sensor: str, band: str) -> torch.Tensor: - ts = self.timeseries[sensor] - return ts[band] - - def __repr__(self): - return f"EoTensor({super(EoTensor, self).__repr__()})" - - def __str__(self): - return f"EoTensor({super(EoTensor, self).__str__()})" - - def __len__(self): - return super(EoTensor, self).__len__() - - def __iter__(self): - return super(EoTensor, self).__iter__() - - def __contains__(self, key): - return super(EoTensor, self).__contains__(key) - - def keys(self): - return super(EoTensor, self).keys() - - def values(self): - return super(EoTensor, self).values() - - def items(self): - return super(EoTensor, self).items() - - def to(self, device): - for key, value in self.sensor_data.items(): - self.sensor_data[key] = value.to(device) - for key, value in self.timeseries.items(): - self.timeseries[key] = value.to(device) - return self - - def to_dtype(self, dtype): - for key, value in self.items(): - self.__setitem__(key, value.to(dtype)) - return self - - def to_device_dtype(self, device, dtype): - for key, value in self.items(): - 
self.__setitem__(key, value.to(device, dtype)) - return self - - def to_dict(self): - return dict(self.items()) - - def to_list(self): - return list(self.values()) - - def to_tuple(self): - return tuple(self.values()) - - def to_dict_list(self): - return list(self.items()) - - def to_dict_tuple(self): - return tuple(self.items()) - - def to_dict_dict(self): - return dict(self.items()) - - def to_dict_dict_list(self): - return [dict] - - -if __name__ == "__main__": - # print("Only one band") - # band = torch.rand(256, 256) - # s2 = SensorData("s2") - # s2["b1"] = band - # print(s2["b1"].shape) - # - # print("All bands") - # s2 = SensorData("s2") - # s2_data = torch.rand(12, 256, 256) - # s2["all"] = s2_data - # print(s2["b1"].shape) - # print(s2["all"].shape) - # - # print("Set two bands, get all bands") - # s2 = SensorData("s2") - # b1 = torch.rand(256, 256) - # b2 = torch.rand(256, 256) - # s2["b1"] = b1 - # s2["b2"] = b2 - # print(s2["b1"].shape) - # print(s2["all"].shape) - - print("One band") - t = EoTensor() - t["s2-b1"] = torch.rand(256, 256) - print(t["s2-b1"].shape) - - print("All bands") - t = EoTensor() - t["s2"] = torch.rand(12, 256, 256) - print(t["s2"].shape) - - print("One band timeserie") - t = EoTensor() - t["s2-b1-ts"] = torch.rand(3, 256, 256) - x = t["s2-b1-ts"] - print(x.shape) - - print("All bands timeserie") - t = EoTensor() - t["s2-ts"] = torch.rand(15, 12, 256, 256) - print("timeserie s2 : ", t["s2-ts"].shape) - t["s2"] = torch.rand(12, 32, 32) - print("s2 : ", t["s2"].shape) - - print("CHECK DEVICE") - device = "cuda" if torch.cuda.is_available() else "cpu" - t.to(device) - print("timeserie s2 : ", t["s2-ts"].device) - print("s2 : ", t["s2"].device) - - # ts = TimeSerieData("s2") - # ts["all"] = torch.rand(9, 12, 256, 256) - # print(ts["all"].shape) - # - # x = torch.rand(9, 256, 256) - # print("DEvice test", x.device) - # ts = TimeSerieData("s2") - # ts["b1"] = torch.rand(9, 256, 256) - # - # device = "cuda" if torch.cuda.is_available() else "cpu" - # print("Device:", device) - # ts.to(device) - # - # print(ts["all"].shape) - # print(ts["all"].device) - # - # print(ts["b1"].shape) - # print(ts["b2"]) diff --git a/pangaea/utils/utils.py b/pangaea/utils/utils.py index d19c5e0d..0960f553 100644 --- a/pangaea/utils/utils.py +++ b/pangaea/utils/utils.py @@ -45,3 +45,8 @@ def get_best_model_ckpt_path(exp_dir: str | Path) -> str: return os.path.join( exp_dir, next(f for f in os.listdir(exp_dir) if f.endswith("_best.pth")) ) + +def get_final_model_ckpt_path(exp_dir: str | Path) -> str: + return os.path.join( + exp_dir, next(f for f in os.listdir(exp_dir) if f.endswith("_final.pth")) + )
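For reference, a minimal sketch of how the refactored pieces above are wired together at runtime: preprocessors are now instantiated from config together with dataset/encoder metadata and wrapped around a `RawGeoFMDataset` via `GeoFMDataset`, instead of each augmentation class wrapping the dataset. This mirrors the updated `pangaea/run.py`; `cfg` is assumed to be the Hydra `DictConfig` passed to `main`, and the variable names are illustrative only.

```python
from hydra.utils import instantiate
from torch.utils.data import DataLoader

from pangaea.datasets.base import GeoFMDataset, RawGeoFMDataset

# Build the train-time preprocessor chain from config. It receives the dataset
# and encoder configs so each step can read meta information (data_mean,
# ignore_index, encoder_input_size, ...) and hand updated meta to the next step.
train_preprocessor = instantiate(
    cfg.preprocessing.train,
    dataset_cfg=cfg.dataset,
    encoder_cfg=cfg.encoder,
    _recursive_=False,
)

# The raw dataset only loads samples; GeoFMDataset applies the preprocessor.
raw_train_dataset: RawGeoFMDataset = instantiate(cfg.dataset, split="train")
train_dataset = GeoFMDataset(raw_train_dataset, train_preprocessor)

train_loader = DataLoader(train_dataset, batch_size=cfg.batch_size)
```

Evaluation then chooses between whole-image and the new sliding-window inference through the evaluator's `inference_mode` argument (`'whole'` or `'sliding'`), with `sliding_inference_batch` capping how many crops are forwarded at once.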