diff --git a/.gitignore b/.gitignore index 93c03efc..b7e439b9 100644 --- a/.gitignore +++ b/.gitignore @@ -8,6 +8,8 @@ pretrained_models/ work-dir/ wandb/ data/* +data +pretrained_models # Scenes for mados tiny # Wow gitignore pattern mathcing is a mess !data/mados diff --git a/configs/dataset/dynamicen.yaml b/configs/dataset/dynamicen.yaml index 7e443271..88ff361a 100644 --- a/configs/dataset/dynamicen.yaml +++ b/configs/dataset/dynamicen.yaml @@ -26,6 +26,14 @@ distribution: - 0.28 - 0.08 +sample_dates: + - 1 + - 5 + - 10 + - 15 + - 20 + - 25 + # data stats bands: optical: @@ -36,15 +44,15 @@ bands: data_mean: optical: - - 1042.59240722656 - - 915.618408203125 - 671.260559082031 + - 915.618408203125 + - 1042.59240722656 - 2605.20922851562 data_std: optical: - - 957.958435058593 - - 715.548767089843 - 596.943908691406 + - 715.548767089843 + - 957.958435058593 - 1059.90319824218 data_min: diff --git a/configs/encoder/croma_joint.yaml b/configs/encoder/croma_joint.yaml index dfc68a67..cfecf55e 100644 --- a/configs/encoder/croma_joint.yaml +++ b/configs/encoder/croma_joint.yaml @@ -28,3 +28,5 @@ output_layers: - 5 - 7 - 11 + +output_dim: 1024 diff --git a/configs/encoder/croma_optical.yaml b/configs/encoder/croma_optical.yaml index 8615598d..7f2a7312 100644 --- a/configs/encoder/croma_optical.yaml +++ b/configs/encoder/croma_optical.yaml @@ -24,3 +24,5 @@ output_layers: - 5 - 7 - 11 + +output_dim: 1024 \ No newline at end of file diff --git a/configs/encoder/croma_sar.yaml b/configs/encoder/croma_sar.yaml index 42925f6b..10dbecd6 100644 --- a/configs/encoder/croma_sar.yaml +++ b/configs/encoder/croma_sar.yaml @@ -14,4 +14,6 @@ output_layers: - 3 - 5 - 7 - - 11 \ No newline at end of file + - 11 + +output_dim: 1024 \ No newline at end of file diff --git a/configs/encoder/dofa.yaml b/configs/encoder/dofa.yaml index 89a087f5..3daf76b0 100644 --- a/configs/encoder/dofa.yaml +++ b/configs/encoder/dofa.yaml @@ -41,3 +41,5 @@ output_layers: - 5 - 7 - 11 + +output_dim: 768 \ No newline at end of file diff --git a/configs/encoder/gfmswin.yaml b/configs/encoder/gfmswin.yaml index 5e4c6fd7..a3d564b5 100644 --- a/configs/encoder/gfmswin.yaml +++ b/configs/encoder/gfmswin.yaml @@ -7,7 +7,6 @@ in_chans: 3 t_patch_size: 3 depth: 12 embed_dim: 128 -output_dim: 1024 img_size: 192 # fixed to 192 to avoid interpolation in checkpoints which leads to drop in performance depths: [ 2, 2, 18, 2 ] num_heads: [ 4, 8, 16, 32 ] @@ -28,3 +27,8 @@ output_layers: - 2 - 3 +output_dim: + - 256 + - 512 + - 1024 + - 1024 diff --git a/configs/encoder/prithvi.yaml b/configs/encoder/prithvi.yaml index e7da1e2c..d82d6681 100644 --- a/configs/encoder/prithvi.yaml +++ b/configs/encoder/prithvi.yaml @@ -27,3 +27,5 @@ output_layers: - 5 - 7 - 11 + +output_dim: 768 \ No newline at end of file diff --git a/configs/encoder/remoteclip.yaml b/configs/encoder/remoteclip.yaml index 46eb3dfd..3288ed19 100644 --- a/configs/encoder/remoteclip.yaml +++ b/configs/encoder/remoteclip.yaml @@ -20,4 +20,6 @@ output_layers: - 3 - 5 - 7 - - 11 \ No newline at end of file + - 11 + +output_dim: 768 \ No newline at end of file diff --git a/configs/encoder/satlasnet_mi.yaml b/configs/encoder/satlasnet_mi.yaml index 2383bdf0..ca3ae89f 100644 --- a/configs/encoder/satlasnet_mi.yaml +++ b/configs/encoder/satlasnet_mi.yaml @@ -5,7 +5,6 @@ download_url: null model_identifier: Sentinel2_SwinB_MI_MS # Multi-Image Multi-Spectral fpn: False input_size: 128 -output_dim: 1024 input_bands: optical: @@ -23,3 +22,15 @@ input_bands: - B8 - B11 - B12 + +output_layers: + 
- 0 + - 1 + - 2 + - 3 + +output_dim: + - 128 + - 256 + - 512 + - 1024 \ No newline at end of file diff --git a/configs/encoder/satlasnet_si.yaml b/configs/encoder/satlasnet_si.yaml index 5759fcb4..d7958c02 100644 --- a/configs/encoder/satlasnet_si.yaml +++ b/configs/encoder/satlasnet_si.yaml @@ -5,7 +5,6 @@ download_url: null model_identifier: Sentinel2_SwinB_SI_MS # Single Image Multi-Spectral fpn: False input_size: 128 -output_dim: 1024 input_bands: optical: @@ -23,3 +22,15 @@ input_bands: - B8 - B11 - B12 + +output_layers: + - 0 + - 1 + - 2 + - 3 + +output_dim: + - 128 + - 256 + - 512 + - 1024 \ No newline at end of file diff --git a/configs/encoder/scalemae.yaml b/configs/encoder/scalemae.yaml index abe99595..4790d512 100644 --- a/configs/encoder/scalemae.yaml +++ b/configs/encoder/scalemae.yaml @@ -25,3 +25,5 @@ output_layers: - 11 - 15 - 23 + +output_dim: 1024 \ No newline at end of file diff --git a/configs/encoder/spectralgpt.yaml b/configs/encoder/spectralgpt.yaml index e880118a..04f4058c 100644 --- a/configs/encoder/spectralgpt.yaml +++ b/configs/encoder/spectralgpt.yaml @@ -3,7 +3,7 @@ encoder_weights: ./pretrained_models/SpectralGPT+.pth download_url: https://zenodo.org/records/8412455/files/SpectralGPT+.pth input_size: 128 -output_dim: 3072 # 768 * (in_chans / t_patch_size) + in_chans: 12 # number of spectral bands t_patch_size: 3 depth: 12 @@ -34,4 +34,5 @@ output_layers: - 7 - 11 +output_dim: 3072 # 768 * (in_chans / t_patch_size) diff --git a/configs/encoder/ssl4eo_data2vec.yaml b/configs/encoder/ssl4eo_data2vec.yaml index 0f6bcb3a..4c38b13b 100644 --- a/configs/encoder/ssl4eo_data2vec.yaml +++ b/configs/encoder/ssl4eo_data2vec.yaml @@ -31,3 +31,5 @@ output_layers: - 5 - 7 - 11 + +output_dim: 384 \ No newline at end of file diff --git a/configs/encoder/ssl4eo_dino.yaml b/configs/encoder/ssl4eo_dino.yaml index b69ff2a2..f006119b 100644 --- a/configs/encoder/ssl4eo_dino.yaml +++ b/configs/encoder/ssl4eo_dino.yaml @@ -31,3 +31,5 @@ output_layers: - 5 - 7 - 11 + +output_dim: 384 \ No newline at end of file diff --git a/configs/encoder/ssl4eo_mae_optical.yaml b/configs/encoder/ssl4eo_mae_optical.yaml index 51c43eda..2caac534 100644 --- a/configs/encoder/ssl4eo_mae_optical.yaml +++ b/configs/encoder/ssl4eo_mae_optical.yaml @@ -31,4 +31,6 @@ output_layers: - 3 - 5 - 7 - - 11 \ No newline at end of file + - 11 + +output_dim: 384 \ No newline at end of file diff --git a/configs/encoder/ssl4eo_mae_sar.yaml b/configs/encoder/ssl4eo_mae_sar.yaml index a3ac0113..a285940a 100644 --- a/configs/encoder/ssl4eo_mae_sar.yaml +++ b/configs/encoder/ssl4eo_mae_sar.yaml @@ -19,4 +19,6 @@ output_layers: - 3 - 5 - 7 - - 11 \ No newline at end of file + - 11 + +output_dim: 384 \ No newline at end of file diff --git a/configs/encoder/ssl4eo_moco.yaml b/configs/encoder/ssl4eo_moco.yaml index 200e8a08..6f7ddbc2 100644 --- a/configs/encoder/ssl4eo_moco.yaml +++ b/configs/encoder/ssl4eo_moco.yaml @@ -31,3 +31,5 @@ output_layers: - 5 - 7 - 11 + +output_dim: 384 diff --git a/configs/encoder/unet_encoder.yaml b/configs/encoder/unet_encoder.yaml index b3bd2450..1db7e09a 100644 --- a/configs/encoder/unet_encoder.yaml +++ b/configs/encoder/unet_encoder.yaml @@ -6,3 +6,9 @@ topology: [64, 128, 256, 512,] input_bands: ${dataset.bands} +output_dim: + - 64 + - 128 + - 256 + - 512 + diff --git a/configs/preprocessing/reg_default.yaml b/configs/preprocessing/reg_default.yaml index bd199fca..1d468b89 100644 --- a/configs/preprocessing/reg_default.yaml +++ b/configs/preprocessing/reg_default.yaml @@ -1,35 +1,21 @@ 
-train: - - _target_: pangaea.engine.data_preprocessor.RegPreprocessor - # overwritten in run - dataset: null - encoder: null +train: + _target_: pangaea.engine.data_preprocessor.Preprocessor + preprocessor_cfg: + - _target_: pangaea.engine.data_preprocessor.RandomCropToEncoder + - _target_: pangaea.engine.data_preprocessor.BandFilter + - _target_: pangaea.engine.data_preprocessor.NormalizeMeanStd + - _target_: pangaea.engine.data_preprocessor.BandPadding - - _target_: pangaea.engine.data_preprocessor.NormalizeMinMax - # overwritten in run - dataset: null - encoder: null - data_min: -1 - data_max: 1 +val: + _target_: pangaea.engine.data_preprocessor.Preprocessor + preprocessor_cfg: + - _target_: pangaea.engine.data_preprocessor.BandFilter + - _target_: pangaea.engine.data_preprocessor.NormalizeMeanStd + - _target_: pangaea.engine.data_preprocessor.BandPadding - - _target_: pangaea.engine.data_preprocessor.RandomCropToEncoder - # overwritten in run - dataset: null - encoder: null - -test: - - _target_: pangaea.engine.data_preprocessor.RegPreprocessor - # overwritten in run - dataset: null - encoder: null - - - _target_: pangaea.engine.data_preprocessor.NormalizeMinMax - # overwritten in run - dataset: null - encoder: null - data_min: -1 - data_max: 1 - - - _target_: pangaea.engine.data_preprocessor.Tile - # overwritten in run - dataset: null - encoder: null +test: + _target_: pangaea.engine.data_preprocessor.Preprocessor + preprocessor_cfg: + - _target_: pangaea.engine.data_preprocessor.BandFilter + - _target_: pangaea.engine.data_preprocessor.NormalizeMeanStd + - _target_: pangaea.engine.data_preprocessor.BandPadding \ No newline at end of file diff --git a/configs/preprocessing/seg_default.yaml b/configs/preprocessing/seg_default.yaml index d6d40842..1d468b89 100644 --- a/configs/preprocessing/seg_default.yaml +++ b/configs/preprocessing/seg_default.yaml @@ -1,31 +1,21 @@ -train: - - _target_: pangaea.engine.data_preprocessor.SegPreprocessor - # overwritten in run - dataset: null - encoder: null +train: + _target_: pangaea.engine.data_preprocessor.Preprocessor + preprocessor_cfg: + - _target_: pangaea.engine.data_preprocessor.RandomCropToEncoder + - _target_: pangaea.engine.data_preprocessor.BandFilter + - _target_: pangaea.engine.data_preprocessor.NormalizeMeanStd + - _target_: pangaea.engine.data_preprocessor.BandPadding - - _target_: pangaea.engine.data_preprocessor.NormalizeMeanStd - # overwritten in run - dataset: null - encoder: null +val: + _target_: pangaea.engine.data_preprocessor.Preprocessor + preprocessor_cfg: + - _target_: pangaea.engine.data_preprocessor.BandFilter + - _target_: pangaea.engine.data_preprocessor.NormalizeMeanStd + - _target_: pangaea.engine.data_preprocessor.BandPadding - - _target_: pangaea.engine.data_preprocessor.RandomCropToEncoder - # overwritten in run - dataset: null - encoder: null - -test: - - _target_: pangaea.engine.data_preprocessor.SegPreprocessor - # overwritten in run - dataset: null - encoder: null - - - _target_: pangaea.engine.data_preprocessor.NormalizeMeanStd - # overwritten in run - dataset: null - encoder: null - - - _target_: pangaea.engine.data_preprocessor.Tile - # overwritten in run - dataset: null - encoder: null +test: + _target_: pangaea.engine.data_preprocessor.Preprocessor + preprocessor_cfg: + - _target_: pangaea.engine.data_preprocessor.BandFilter + - _target_: pangaea.engine.data_preprocessor.NormalizeMeanStd + - _target_: pangaea.engine.data_preprocessor.BandPadding \ No newline at end of file diff --git 
a/configs/preprocessing/seg_importance_crop.yaml b/configs/preprocessing/seg_importance_crop.yaml index f2147762..740e5ab7 100644 --- a/configs/preprocessing/seg_importance_crop.yaml +++ b/configs/preprocessing/seg_importance_crop.yaml @@ -1,31 +1,21 @@ -train: - - _target_: pangaea.engine.data_preprocessor.SegPreprocessor - # overwritten in run - dataset: null - encoder: null +train: + _target_: pangaea.engine.data_preprocessor.Preprocessor + preprocessor_cfg: + - _target_: pangaea.engine.data_preprocessor.ImportanceRandomCropToEncoder + - _target_: pangaea.engine.data_preprocessor.BandFilter + - _target_: pangaea.engine.data_preprocessor.NormalizeMeanStd + - _target_: pangaea.engine.data_preprocessor.BandPadding - - _target_: pangaea.engine.data_preprocessor.NormalizeMeanStd - # overwritten in run - dataset: null - encoder: null +val: + _target_: pangaea.engine.data_preprocessor.Preprocessor + preprocessor_cfg: + - _target_: pangaea.engine.data_preprocessor.BandFilter + - _target_: pangaea.engine.data_preprocessor.NormalizeMeanStd + - _target_: pangaea.engine.data_preprocessor.BandPadding - - _target_: pangaea.engine.data_preprocessor.ImportanceRandomCropToEncoder - # overwritten in run - dataset: null - encoder: null - -test: - - _target_: pangaea.engine.data_preprocessor.SegPreprocessor - # overwritten in run - dataset: null - encoder: null - - - _target_: pangaea.engine.data_preprocessor.NormalizeMeanStd - # overwritten in run - dataset: null - encoder: null - - - _target_: pangaea.engine.data_preprocessor.Tile - # overwritten in run - dataset: null - encoder: null +test: + _target_: pangaea.engine.data_preprocessor.Preprocessor + preprocessor_cfg: + - _target_: pangaea.engine.data_preprocessor.BandFilter + - _target_: pangaea.engine.data_preprocessor.NormalizeMeanStd + - _target_: pangaea.engine.data_preprocessor.BandPadding \ No newline at end of file diff --git a/configs/preprocessing/seg_resize.yaml b/configs/preprocessing/seg_resize.yaml index aaa37718..625a5e70 100644 --- a/configs/preprocessing/seg_resize.yaml +++ b/configs/preprocessing/seg_resize.yaml @@ -1,32 +1,23 @@ -train: - - _target_: pangaea.engine.data_preprocessor.SegPreprocessor - # overwritten in run - dataset: null - encoder: null +train: + _target_: pangaea.engine.data_preprocessor.Preprocessor + preprocessor_cfg: + - _target_: pangaea.engine.data_preprocessor.ResizeToEncoder + - _target_: pangaea.engine.data_preprocessor.BandFilter + - _target_: pangaea.engine.data_preprocessor.NormalizeMeanStd + - _target_: pangaea.engine.data_preprocessor.BandPadding - - _target_: pangaea.engine.data_preprocessor.NormalizeMeanStd - # overwritten in run - dataset: null - encoder: null - - - _target_: pangaea.engine.data_preprocessor.ResizeToEncoder - # overwritten in run - dataset: null - encoder: null - -test: - - _target_: pangaea.engine.data_preprocessor.SegPreprocessor - # overwritten in run - dataset: null - encoder: null - - - _target_: pangaea.engine.data_preprocessor.NormalizeMeanStd - # overwritten in run - dataset: null - encoder: null - - - _target_: pangaea.engine.data_preprocessor.ResizeToEncoder - # overwritten in run - dataset: null - encoder: null +val: + _target_: pangaea.engine.data_preprocessor.Preprocessor + preprocessor_cfg: + - _target_: pangaea.engine.data_preprocessor.ResizeToEncoder + - _target_: pangaea.engine.data_preprocessor.BandFilter + - _target_: pangaea.engine.data_preprocessor.NormalizeMeanStd + - _target_: pangaea.engine.data_preprocessor.BandPadding +test: + _target_: 
pangaea.engine.data_preprocessor.Preprocessor + preprocessor_cfg: + - _target_: pangaea.engine.data_preprocessor.ResizeToEncoder + - _target_: pangaea.engine.data_preprocessor.BandFilter + - _target_: pangaea.engine.data_preprocessor.NormalizeMeanStd + - _target_: pangaea.engine.data_preprocessor.BandPadding \ No newline at end of file diff --git a/configs/task/change_detection.yaml b/configs/task/change_detection.yaml index a4a64133..179e85ef 100644 --- a/configs/task/change_detection.yaml +++ b/configs/task/change_detection.yaml @@ -26,5 +26,7 @@ evaluator: exp_dir: null device: null use_wandb: ${use_wandb} + inference_mode: sliding + sliding_inference_batch: 8 diff --git a/configs/task/regression.yaml b/configs/task/regression.yaml index 623733b2..f596b59b 100644 --- a/configs/task/regression.yaml +++ b/configs/task/regression.yaml @@ -25,4 +25,6 @@ evaluator: val_loader: null exp_dir: null device: null - use_wandb: ${use_wandb} \ No newline at end of file + use_wandb: ${use_wandb} + inference_mode: sliding + sliding_inference_batch: 8 \ No newline at end of file diff --git a/configs/task/segmentation.yaml b/configs/task/segmentation.yaml index ff1a7404..b41224ad 100644 --- a/configs/task/segmentation.yaml +++ b/configs/task/segmentation.yaml @@ -26,5 +26,7 @@ evaluator: exp_dir: null device: null use_wandb: ${use_wandb} + inference_mode: sliding + sliding_inference_batch: 8 diff --git a/configs/train.yaml b/configs/train.yaml index 239db1f8..d447351c 100644 --- a/configs/train.yaml +++ b/configs/train.yaml @@ -5,9 +5,11 @@ use_wandb: false wandb_run_id: null # TRAINING -num_workers: 1 -batch_size: 32 +num_workers: 4 +batch_size: 8 +test_num_workers: 4 +test_batch_size: 1 # EXPERIMENT finetune: false diff --git a/pangaea/datasets/ai4smallfarms.py b/pangaea/datasets/ai4smallfarms.py index 8436de13..c4250065 100644 --- a/pangaea/datasets/ai4smallfarms.py +++ b/pangaea/datasets/ai4smallfarms.py @@ -8,10 +8,10 @@ from pyDataverse.api import DataAccessApi, NativeApi from tifffile import imread -from pangaea.datasets.base import GeoFMDataset +from pangaea.datasets.base import RawGeoFMDataset -class AI4SmallFarms(GeoFMDataset): +class AI4SmallFarms(RawGeoFMDataset): def __init__( self, split: str, diff --git a/pangaea/datasets/base.py b/pangaea/datasets/base.py index f611b5c1..c7a40330 100644 --- a/pangaea/datasets/base.py +++ b/pangaea/datasets/base.py @@ -2,7 +2,9 @@ from torch.utils.data import Dataset, Subset import os -class GeoFMDataset(Dataset): +from pangaea.engine.data_preprocessor import Preprocessor + +class RawGeoFMDataset(Dataset): """Base class for all datasets.""" def __init__( @@ -117,30 +119,69 @@ def download(self) -> None: raise NotImplementedError +class GeoFMDataset(Dataset): + """Base class for all datasets.""" + + def __init__( + self, + dataset: RawGeoFMDataset, + preprocessor: Preprocessor = None, + ): + """Initializes the dataset. + + Args: + + """ + super().__init__() + self.__dict__.update(dataset.__dict__) + self.raw_dataset = dataset + self.preprocessor = preprocessor + + + + def __len__(self) -> int: + """Returns the length of the dataset. + + Returns: + int: length of the dataset + """ + + return len(self.raw_dataset) + + def __getitem__(self, i: int) -> dict[str, torch.Tensor | dict[str, torch.Tensor]]: + """Returns the i-th item of the dataset. 
+ + Args: + i (int): index of the item + + Raises: + NotImplementedError: raise if the method is not implemented + + Returns: + dict[str, torch.Tensor | dict[str, torch.Tensor]]: output dictionary follwing the format + {"image": + { + "optical": torch.Tensor of shape (C H W) (or (C T H W) if multi-temporal dataset), + "sar": torch.Tensor of shape (C H W) (or (C T H W) if multi-temporal dataset) + }, + "target": torch.Tensor of shape (H W), + "metadata": dict}. + """ + + output = self.raw_dataset[i] + if self.preprocessor is not None: + output = self.preprocessor(output) + + return output + class GeoFMSubset(Subset): """Custom subset class that retains dataset attributes.""" def __init__(self, dataset, indices): super().__init__(dataset, indices) - + # Copy relevant attributes from the original dataset - self.dataset_name = getattr(dataset, 'dataset_name', None) - self.root_path = getattr(dataset, 'root_path', None) - self.auto_download = getattr(dataset, 'auto_download', None) - self.download_url = getattr(dataset, 'download_url', None) - self.img_size = getattr(dataset, 'img_size', None) - self.multi_temporal = getattr(dataset, 'multi_temporal', None) - self.multi_modal = getattr(dataset, 'multi_modal', None) - self.ignore_index = getattr(dataset, 'ignore_index', None) - self.num_classes = getattr(dataset, 'num_classes', None) - self.classes = getattr(dataset, 'classes', None) - self.distribution = getattr(dataset, 'distribution', None) - self.bands = getattr(dataset, 'bands', None) - self.data_mean = getattr(dataset, 'data_mean', None) - self.data_std = getattr(dataset, 'data_std', None) - self.data_min = getattr(dataset, 'data_min', None) - self.data_max = getattr(dataset, 'data_max', None) - self.split = getattr(dataset, 'split', None) + self.__dict__.update(dataset.__dict__) def filter_by_indices(self, indices): """Apply filtering by indices directly in this subset.""" diff --git a/pangaea/datasets/biomassters.py b/pangaea/datasets/biomassters.py index 8e98b785..b895f06a 100644 --- a/pangaea/datasets/biomassters.py +++ b/pangaea/datasets/biomassters.py @@ -7,7 +7,7 @@ from os.path import join as opj from pangaea.datasets.utils import read_tif -from pangaea.datasets.base import GeoFMDataset +from pangaea.datasets.base import RawGeoFMDataset def read_imgs(multi_temporal, temp , fname, data_dir, img_size): imgs_s1, imgs_s2, mask = [], [], [] @@ -49,7 +49,7 @@ def read_imgs(multi_temporal, temp , fname, data_dir, img_size): imgs_s2 = np.stack(imgs_s2, axis=1) return imgs_s1, imgs_s2, mask -class BioMassters(GeoFMDataset): +class BioMassters(RawGeoFMDataset): def __init__( self, split: str, diff --git a/pangaea/datasets/croptypemapping.py b/pangaea/datasets/croptypemapping.py index e59e4eec..94fde030 100644 --- a/pangaea/datasets/croptypemapping.py +++ b/pangaea/datasets/croptypemapping.py @@ -12,12 +12,12 @@ import torch -from pangaea.datasets.base import GeoFMDataset +from pangaea.datasets.base import RawGeoFMDataset # from utils.registry import DATASET_REGISTRY # @DATASET_REGISTRY.register() -class CropTypeMappingSouthSudan(GeoFMDataset): +class CropTypeMappingSouthSudan(RawGeoFMDataset): def __init__( self, split: str, diff --git a/pangaea/datasets/fivebillionpixels.py b/pangaea/datasets/fivebillionpixels.py index 4cae19ef..833a4554 100644 --- a/pangaea/datasets/fivebillionpixels.py +++ b/pangaea/datasets/fivebillionpixels.py @@ -19,9 +19,9 @@ import tarfile from pangaea.datasets.utils import DownloadProgressBar -from pangaea.datasets.base import GeoFMDataset +from 
pangaea.datasets.base import RawGeoFMDataset -class FiveBillionPixels(GeoFMDataset): +class FiveBillionPixels(RawGeoFMDataset): def __init__( self, split: str, diff --git a/pangaea/datasets/hlsburnscars.py b/pangaea/datasets/hlsburnscars.py index 0678660e..31b057e0 100644 --- a/pangaea/datasets/hlsburnscars.py +++ b/pangaea/datasets/hlsburnscars.py @@ -14,10 +14,10 @@ import tarfile from pangaea.datasets.utils import DownloadProgressBar -from pangaea.datasets.base import GeoFMDataset +from pangaea.datasets.base import RawGeoFMDataset -class HLSBurnScars(GeoFMDataset): +class HLSBurnScars(RawGeoFMDataset): def __init__( self, split: str, diff --git a/pangaea/datasets/mados.py b/pangaea/datasets/mados.py index bcc6f620..be2be97c 100644 --- a/pangaea/datasets/mados.py +++ b/pangaea/datasets/mados.py @@ -18,14 +18,14 @@ import torchvision.transforms as T from pangaea.datasets.utils import DownloadProgressBar -from pangaea.datasets.base import GeoFMDataset +from pangaea.datasets.base import RawGeoFMDataset ############################################################### # MADOS DATASET # ############################################################### # @DATASET_REGISTRY.register() -class MADOS(GeoFMDataset): +class MADOS(RawGeoFMDataset): def __init__( self, split: str, diff --git a/pangaea/datasets/pastis.py b/pangaea/datasets/pastis.py index 9abc3ac2..d0814e37 100644 --- a/pangaea/datasets/pastis.py +++ b/pangaea/datasets/pastis.py @@ -14,7 +14,7 @@ import torch from einops import rearrange -from pangaea.datasets.base import GeoFMDataset +from pangaea.datasets.base import RawGeoFMDataset def prepare_dates(date_dict, reference_date): @@ -64,7 +64,7 @@ def split_image(image_tensor, nb_split, id): ].float() -class Pastis(GeoFMDataset): +class Pastis(RawGeoFMDataset): def __init__( self, split: str, diff --git a/pangaea/datasets/sen1floods11.py b/pangaea/datasets/sen1floods11.py index 76680c7d..0f4e9193 100644 --- a/pangaea/datasets/sen1floods11.py +++ b/pangaea/datasets/sen1floods11.py @@ -8,9 +8,9 @@ import torch from pangaea.datasets.utils import download_bucket_concurrently -from pangaea.datasets.base import GeoFMDataset +from pangaea.datasets.base import RawGeoFMDataset -class Sen1Floods11(GeoFMDataset): +class Sen1Floods11(RawGeoFMDataset): def __init__( self, @@ -153,18 +153,19 @@ def __getitem__(self, index): s2_image = torch.from_numpy(s2_image).float() s1_image = torch.from_numpy(s1_image).float() - target = torch.from_numpy(target) + target = torch.from_numpy(target).long() output = { 'image': { - 'optical': s2_image, - 'sar' : s1_image, + 'optical': s2_image.unsqueeze(1), + 'sar' : s1_image.unsqueeze(1), }, 'target': target, 'metadata': { "timestamp": timestamp, } } + return output @staticmethod @@ -175,4 +176,3 @@ def download(self, silent=False): return download_bucket_concurrently(self.gcs_bucket, self.root_path) - diff --git a/pangaea/datasets/spacenet7.py b/pangaea/datasets/spacenet7.py index 144eb4b3..a34e31c5 100644 --- a/pangaea/datasets/spacenet7.py +++ b/pangaea/datasets/spacenet7.py @@ -23,7 +23,7 @@ from abc import abstractmethod from pangaea.datasets.utils import DownloadProgressBar -from pangaea.datasets.base import GeoFMDataset +from pangaea.datasets.base import RawGeoFMDataset # from utils.registry import DATASET_REGISTRY # train/val/test split from https://doi.org/10.3390/rs15215135 @@ -98,7 +98,7 @@ # SPACENET 7 DATASET # ############################################################### -class AbstractSN7(GeoFMDataset): +class AbstractSN7(RawGeoFMDataset): def __init__( 
self, diff --git a/pangaea/datasets/utae_dynamicen.py b/pangaea/datasets/utae_dynamicen.py index 7705e331..8a5ef762 100644 --- a/pangaea/datasets/utae_dynamicen.py +++ b/pangaea/datasets/utae_dynamicen.py @@ -11,12 +11,12 @@ # import random # from PIL import Image -from pangaea.datasets.base import GeoFMDataset +from pangaea.datasets.base import RawGeoFMDataset # from utils.registry import DATASET_REGISTRY # @DATASET_REGISTRY.register() -class DynamicEarthNet(GeoFMDataset): +class DynamicEarthNet(RawGeoFMDataset): def __init__( self, split: str, @@ -36,6 +36,7 @@ def __init__( data_max: dict[str, list[str]], download_url: str, auto_download: bool, + sample_dates: list[int] ): """Initialize the DynamicEarthNet dataset. Link: https://github.com/aysim/dynnet @@ -101,7 +102,7 @@ def __init__( self.download_url = download_url self.auto_download = auto_download - self.mode = 'weekly' + self.sample_dates = [str(d).rjust(2,'0') for d in sample_dates] self.files = [] @@ -110,6 +111,7 @@ def __init__( self.set_files() + def set_files(self): self.file_list = os.path.join(self.root_path, "dynnet_training_splits", f"{self.split}" + ".txt") @@ -118,97 +120,22 @@ def set_files(self): self.files, self.labels, self.year_months = list(zip(*file_list)) self.files = [f.replace('/reprocess-cropped/UTM-24000/', '/planet/') for f in self.files] - if self.mode == 'daily': - self.all_days = list(range(len(self.files))) - - for i in range(len(self.files)): - self.planet, self.day = [], [] - date_count = 0 - for _, _, infiles in os.walk(os.path.join(self.root_path, self.files[i][1:])): - for infile in sorted(infiles): - if infile.startswith(self.year_months[i]): - self.planet.append(os.path.join(self.files[i], infile)) - self.day.append((datetime(int(str(infile.split('.')[0])[:4]), int(str(infile.split('.')[0][5:7])), - int(str(infile.split('.')[0])[8:])) - self.reference_date).days) - date_count += 1 - self.all_days[i] = list(zip(self.planet, self.day)) - self.all_days[i].insert(0, date_count) - - else: - self.planet, self.day = [], [] - if self.mode == 'weekly': - self.dates = ['01', '05', '10', '15', '20', '25'] - elif self.mode == 'single': - self.dates = ['01'] - - for i, year_month in enumerate(self.year_months): - for date in self.dates: - curr_date = year_month + '-' + date - self.planet.append(os.path.join(self.files[i], curr_date + '.tif')) - self.day.append((datetime(int(str(curr_date)[:4]), int(str(curr_date[5:7])), - int(str(curr_date)[8:])) - self.reference_date).days) - self.planet_day = list(zip(*[iter(self.planet)] * len(self.dates), *[iter(self.day)] * len(self.dates))) - - - def load_data(self, index): - cur_images, cur_dates = [], [] - if self.mode == 'daily': - for i in range(1, self.all_days[index][0]+1): - img = rasterio.open(os.path.join(self.root_path, self.all_days[index][i][0][1:])) - red = img.read(3) - green = img.read(2) - blue = img.read(1) - nir = img.read(4) - image = np.dstack((red, green, blue, nir)) - cur_images.append(np.expand_dims(np.asarray(image, dtype=np.float32), axis=0)) # np.array already\ - cur_dates.append(self.all_days[index][i][1]) - - image_stack = np.concatenate(cur_images, axis=0) - dates = torch.from_numpy(np.array(cur_dates, dtype=np.int32)) - label = rasterio.open(os.path.join(self.root_path, self.labels[index][1:])) - label = label.read() - mask = np.zeros((label.shape[1], label.shape[2]), dtype=np.int32) - - for i in range(self.num_classes + 1): - if i == 6: - mask[label[i, :, :] == 255] = -1 - else: - mask[label[i, :, :] == 255] = i - - return (image_stack, 
dates), mask - - else: - for i in range(len(self.dates)): - # read .tif - img = rasterio.open(os.path.join(self.root_path, self.planet_day[index][i][1:])) - red = img.read(3) - green = img.read(2) - blue = img.read(1) - nir = img.read(4) - image = np.dstack((red, green, blue, nir)) - cur_images.append(np.expand_dims(np.asarray(image, dtype=np.float32), axis=0)) # np.array already\ - image_stack = np.concatenate(cur_images, axis=0) - dates = torch.from_numpy(np.array(self.planet_day[index][len(self.dates):], dtype=np.int32)) - label = rasterio.open(os.path.join(self.root_path, self.labels[index][1:])) - label = label.read() - mask = np.zeros((label.shape[1], label.shape[2]), dtype=np.int32) - - for i in range(self.num_classes + 1): - if i == 6: - mask[label[i, :, :] == 255] = -1 - else: - mask[label[i, :, :] == 255] = i - - return (image_stack, dates), mask + self.all_sequences = [] + for f, ym in zip(self.files, self.year_months): + images = [] + for date in self.sample_dates: + image_file = os.path.join(self.root_path, f[1:], f"{ym}-{date}.npy") + assert os.path.isfile(image_file), f"{image_file} does not exist" + images.append(image_file) + self.all_sequences.append(images) def __len__(self): return len(self.files) def __getitem__(self, index): - (images, dates), label = self.load_data(index) - - images = torch.from_numpy(images).permute(3, 0, 1, 2)#.transpose(0, 1) - label = torch.from_numpy(np.array(label, dtype=np.int32)).long() + images = [np.load(seq) for seq in self.all_sequences[index]] + images = torch.from_numpy(np.stack(images, axis=0)).transpose(0, 1).float() + label = torch.from_numpy(np.load(os.path.join(self.root_path, self.labels[index][1:].replace('tif', 'npy')))).long() output = { 'image': { @@ -219,14 +146,7 @@ def __getitem__(self, index): } return output - #return {'img': images, 'label': label, 'meta': dates} - - # @staticmethod - # def get_splits(dataset_config): - # dataset_train = DynamicEarthNet(cfg=dataset_config, split="train") - # dataset_val = DynamicEarthNet(cfg=dataset_config, split="val") - # dataset_test = DynamicEarthNet(cfg=dataset_config, split="test") - # return dataset_train, dataset_val, dataset_test + @staticmethod def download(self, silent=False): diff --git a/pangaea/datasets/xview2.py b/pangaea/datasets/xview2.py index 845f97c9..12c4bed4 100644 --- a/pangaea/datasets/xview2.py +++ b/pangaea/datasets/xview2.py @@ -15,10 +15,10 @@ import tarfile from pangaea.datasets.utils import DownloadProgressBar -from pangaea.datasets.base import GeoFMDataset +from pangaea.datasets.base import RawGeoFMDataset -class xView2(GeoFMDataset): +class xView2(RawGeoFMDataset): def __init__( self, split: str, diff --git a/pangaea/decoders/ltae.py b/pangaea/decoders/ltae.py index 71ab4950..82e6c7c0 100644 --- a/pangaea/decoders/ltae.py +++ b/pangaea/decoders/ltae.py @@ -152,7 +152,7 @@ def forward(self, x, batch_positions=None, pad_mask=None, return_comp=False): attn = attn.view(self.n_head, sz_b, h, w, seq_len).permute( 0, 1, 4, 2, 3 - ) # head x b x t x h x w + ).contiguous() # head x b x t x h x w if self.return_att: return out, attn diff --git a/pangaea/decoders/upernet.py b/pangaea/decoders/upernet.py index 7361a004..76ed41c7 100644 --- a/pangaea/decoders/upernet.py +++ b/pangaea/decoders/upernet.py @@ -42,14 +42,25 @@ def __init__( for param in self.encoder.parameters(): param.requires_grad = False + self.input_layers = self.encoder.output_layers + self.input_layers_num = len(self.input_layers) + + self.in_channels = [dim * feature_multiplier for dim in 
self.encoder.output_dim] + + if self.encoder.pyramid_output: + rescales = [1 for _ in range(self.input_layers_num)] + else: + scales = [4, 2, 1, 0.5] + rescales = [scales[int(i / self.input_layers_num * 4)] for i in range(self.input_layers_num)] + + self.neck = Feature2Pyramid( - embed_dim=encoder.output_dim * feature_multiplier, - rescales=[4, 2, 1, 0.5], + embed_dim=self.in_channels, + rescales=rescales, ) self.align_corners = False - self.in_channels = [encoder.output_dim * feature_multiplier for _ in range(4)] self.channels = channels self.num_classes = num_classes @@ -353,6 +364,8 @@ def __init__( else: raise NotImplementedError + encoder.enforce_single_temporal() + super().__init__( encoder=encoder, num_classes=num_classes, @@ -475,13 +488,24 @@ def __init__( for param in self.encoder.parameters(): param.requires_grad = False + + self.input_layers = self.encoder.output_layers + self.input_layers_num = len(self.input_layers) + + self.in_channels = [dim * feature_multiplier for dim in self.encoder.output_dim] + + if self.encoder.pyramid_output: + rescales = [1 for _ in range(self.input_layers_num)] + else: + scales = [4, 2, 1, 0.5] + rescales = [scales[int(i / self.input_layers_num * 4)] for i in range(self.input_layers_num)] + self.neck = Feature2Pyramid( - embed_dim=encoder.output_dim, rescales=[4, 2, 1, 0.5] + embed_dim=self.in_channels, rescales=rescales ) self.align_corners = False - self.in_channels = [encoder.output_dim for _ in range(4)] self.channels = channels self.num_classes = 1 # regression @@ -745,57 +769,45 @@ class Feature2Pyramid(nn.Module): embed_dims (int): Embedding dimension. rescales (list[float]): Different sampling multiples were used to obtain pyramid features. Default: [4, 2, 1, 0.5]. - norm_cfg (dict): Config dict for normalization layer. - Default: dict(type='SyncBN', requires_grad=True). 
""" def __init__( self, embed_dim, - rescales=[4, 2, 1, 0.5], - norm_cfg=dict(type="SyncBN", requires_grad=True), + rescales=(4, 2, 1, 0.5), ): super().__init__() self.rescales = rescales self.upsample_4x = None - for k in self.rescales: + self.ops = nn.ModuleList() + + for i, k in enumerate(self.rescales): if k == 4: - self.upsample_4x = nn.Sequential( - nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2), - nn.SyncBatchNorm(embed_dim), + self.ops.append(nn.Sequential( + nn.ConvTranspose2d(embed_dim[i], embed_dim[i], kernel_size=2, stride=2), + nn.SyncBatchNorm(embed_dim[i]), nn.GELU(), - nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2), - ) + nn.ConvTranspose2d(embed_dim[i], embed_dim[i], kernel_size=2, stride=2), + )) elif k == 2: - self.upsample_2x = nn.Sequential( - nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2) - ) + self.ops.append(nn.Sequential( + nn.ConvTranspose2d(embed_dim[i], embed_dim[i], kernel_size=2, stride=2) + )) elif k == 1: - self.identity = nn.Identity() + self.ops.append(nn.Identity()) elif k == 0.5: - self.downsample_2x = nn.MaxPool2d(kernel_size=2, stride=2) + self.ops.append(nn.MaxPool2d(kernel_size=2, stride=2)) elif k == 0.25: - self.downsample_4x = nn.MaxPool2d(kernel_size=4, stride=4) + self.ops.append(nn.MaxPool2d(kernel_size=4, stride=4)) else: raise KeyError(f"invalid {k} for feature2pyramid") + + def forward(self, inputs): assert len(inputs) == len(self.rescales) outputs = [] - if self.upsample_4x is not None: - ops = [ - self.upsample_4x, - self.upsample_2x, - self.identity, - self.downsample_2x, - ] - else: - ops = [ - self.upsample_2x, - self.identity, - self.downsample_2x, - self.downsample_4x, - ] + for i in range(len(inputs)): - outputs.append(ops[i](inputs[i])) + outputs.append(self.ops[i](inputs[i])) return tuple(outputs) diff --git a/pangaea/encoders/base.py b/pangaea/encoders/base.py index 382c898a..981e7aa7 100644 --- a/pangaea/encoders/base.py +++ b/pangaea/encoders/base.py @@ -57,9 +57,11 @@ def __init__( input_bands: dict[str, list[str]], input_size: int, embed_dim: int, - output_dim: int, + output_layers: list[int], + output_dim: int | list[int], multi_temporal: bool, multi_temporal_output: bool, + pyramid_output: bool, encoder_weights: str | Path, download_url: str, ) -> None: @@ -82,10 +84,13 @@ def __init__( self.input_bands = input_bands self.input_size = input_size self.embed_dim = embed_dim - self.output_dim = output_dim + self.output_layers = output_layers + self.output_dim = [output_dim for _ in output_layers] if isinstance(output_dim, int) else output_dim self.encoder_weights = encoder_weights self.multi_temporal = multi_temporal self.multi_temporal_output = multi_temporal_output + + self.pyramid_output = pyramid_output self.download_url = download_url # download_model if necessary @@ -102,6 +107,11 @@ def load_encoder_weights(self, logger: Logger) -> None: """ raise NotImplementedError + def enforce_single_temporal(self): + return + #self.multi_temporal = False + #self.multi_temporal_fusion = False + def parameters_warning( self, missing: dict[str, torch.Size], diff --git a/pangaea/encoders/croma_encoder.py b/pangaea/encoders/croma_encoder.py index dfea553d..4148a053 100644 --- a/pangaea/encoders/croma_encoder.py +++ b/pangaea/encoders/croma_encoder.py @@ -43,6 +43,7 @@ def __init__( input_size: int, input_bands: dict[str, list[str]], output_layers: int | list[int], + output_dim: int | list[int], download_url: str, size="base", ): @@ -52,9 +53,11 @@ def __init__( 
input_bands=input_bands, input_size=input_size, embed_dim=768, - output_dim=768, + output_dim=output_dim, + output_layers=output_layers, multi_temporal=False, multi_temporal_output=False, + pyramid_output=False, download_url=download_url, ) @@ -73,7 +76,6 @@ def __init__( self.num_heads = 16 self.patch_size = 8 - self.output_dim = self.embed_dim self.num_patches = int((self.img_size / 8) ** 2) self.s2_channels = 12 # fixed at 12 multispectral optical channels self.attn_bias = get_2dalibi( @@ -156,6 +158,7 @@ def __init__( input_size: int, input_bands: dict[str, list[str]], output_layers: int | list[int], + output_dim: int | list[int], download_url: str, size="base", ): @@ -165,9 +168,11 @@ def __init__( input_bands=input_bands, input_size=input_size, embed_dim=768, - output_dim=768, + output_layers=output_layers, + output_dim=output_dim, multi_temporal=False, multi_temporal_output=False, + pyramid_output=False, download_url=download_url, ) @@ -186,7 +191,6 @@ def __init__( self.num_heads = 16 self.patch_size = 8 - self.output_dim = self.embed_dim self.num_patches = int((self.img_size / 8) ** 2) self.s1_channels = 2 # fixed at 2 SAR backscatter channels @@ -277,8 +281,11 @@ def __init__( input_size: int, input_bands: dict[str, list[str]], output_layers: int | list[int], + output_dim: int | list[int], download_url: str, + size="base", + ): super().__init__( model_name="croma_joint", @@ -286,9 +293,11 @@ def __init__( input_bands=input_bands, input_size=input_size, embed_dim=768, - output_dim=768, + output_layers=output_layers, + output_dim=output_dim, multi_temporal=False, multi_temporal_output=False, + pyramid_output=False, download_url=download_url, ) @@ -307,7 +316,6 @@ def __init__( self.num_heads = 16 self.patch_size = 8 - self.output_dim = self.embed_dim self.num_patches = int((self.img_size / 8) ** 2) self.s1_channels = 2 # fixed at 2 SAR backscatter channels diff --git a/pangaea/encoders/dofa_encoder.py b/pangaea/encoders/dofa_encoder.py index a5376b37..e636a32f 100644 --- a/pangaea/encoders/dofa_encoder.py +++ b/pangaea/encoders/dofa_encoder.py @@ -180,6 +180,7 @@ def __init__( input_bands: dict[str, list[str]], input_size: int, embed_dim: int, + output_dim: int | list[int], output_layers: int | list[int], wave_list: dict[str, dict[str, float]], download_url: str, @@ -198,9 +199,11 @@ def __init__( input_bands=input_bands, input_size=input_size, embed_dim=embed_dim, - output_dim=embed_dim, + output_layers=output_layers, + output_dim=output_dim, multi_temporal=False, multi_temporal_output=False, + pyramid_output=False, download_url=download_url, ) @@ -216,9 +219,7 @@ def __init__( self.wave_list[m][bi] for m, b in self.input_bands.items() for bi in b ] - self.norm = norm_layer( - [embed_dim, (self.img_size // patch_size), (self.img_size // patch_size)] - ) + self.norm = norm_layer(self.embed_dim) self.patch_embed = Dynamic_MLP_OFA( wv_planes=128, inter_dim=128, kernel_size=16, embed_dim=embed_dim @@ -261,6 +262,8 @@ def forward(self, image): output = [] for i, blk in enumerate(self.blocks): x = blk(x) + if i == len(self.blocks) - 1: + x = self.norm(x) if i in self.output_layers: out = ( x[:, 1:] @@ -273,8 +276,7 @@ def forward(self, image): ) .contiguous() ) - if self.use_norm: - out = self.norm(out) + output.append(out) return output diff --git a/pangaea/encoders/gfmswin_encoder.py b/pangaea/encoders/gfmswin_encoder.py index bccbbb8d..49b640a2 100644 --- a/pangaea/encoders/gfmswin_encoder.py +++ b/pangaea/encoders/gfmswin_encoder.py @@ -645,8 +645,10 @@ def __init__( 
input_size=input_size, embed_dim=embed_dim, output_dim=output_dim, + output_layers=output_layers, multi_temporal=False, multi_temporal_output=False, + pyramid_output=True, download_url=download_url, ) @@ -662,7 +664,6 @@ def __init__( self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) self.mlp_ratio = mlp_ratio # self.use_norm = use_norm - self.out_dim = output_dim self.only_output_last = only_output_last # split image into non-overlapping patches self.patch_embed = PatchEmbed( @@ -767,16 +768,9 @@ def forward(self, image): for i, layer in enumerate(self.layers): x = layer(x) B, L, C = x.shape - if not self.only_output_last: - if i in self.output_layers: - out = x.reshape(B, C, int(L**0.5), int(L**0.5)).repeat( - 1, self.out_dim // C, 1, 1 - ) - out = out.view(B, self.out_dim, int(L**0.5), int(L**0.5)) - output.append(out) - else: - if i == self.num_layers - 1: - output.extend([x.reshape(B, C, int(L**0.5), int(L**0.5))] * 4) + if i in self.output_layers: + out = x.transpose(1, 2).view(B, C, int(L**0.5), int(L**0.5)) + output.append(out) # if self.use_norm: # x = self.norm(x) diff --git a/pangaea/encoders/prithvi_encoder.py b/pangaea/encoders/prithvi_encoder.py index e491c1e6..7a380638 100644 --- a/pangaea/encoders/prithvi_encoder.py +++ b/pangaea/encoders/prithvi_encoder.py @@ -45,6 +45,7 @@ def __init__( encoder_weights: str | Path, input_bands: dict[str, list[str]], input_size: int, + output_dim: int | list[int], output_layers: int | list[int], download_url: str, patch_size=16, @@ -63,15 +64,18 @@ def __init__( input_bands=input_bands, input_size=input_size, embed_dim=embed_dim, - output_dim=embed_dim, + output_layers=output_layers, + output_dim=output_dim, multi_temporal=True, multi_temporal_output=True, + pyramid_output=False, download_url=download_url, ) self.output_layers = output_layers self.img_size = self.input_size + self.tublet_size = tubelet_size if num_frames: self.num_frames = num_frames @@ -193,6 +197,32 @@ def forward(self, image): return output + def enforce_single_temporal(self): + + self.num_frames = 1 + + self.patch_embed = PatchEmbed( + self.input_size, + self.patch_size, + 1, + self.tublet_size, + self.in_chans, + self.embed_dim, + ) + num_patches = self.patch_embed.num_patches + self.pos_embed = nn.Parameter( + torch.zeros(1, num_patches + 1, self.embed_dim), requires_grad=False + ) + + pos_embed = get_3d_sincos_pos_embed( + self.pos_embed.shape[-1], self.patch_embed.grid_size, cls_token=True + ) + self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) + + # initialize patch_embed like nn.Linear (instead of nn.Conv2d) + w = self.patch_embed.proj.weight.data + torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + class PatchEmbed(nn.Module): """Frames of 2D Images to Patch Embedding The 3D version of timm.models.vision_transformer.PatchEmbed diff --git a/pangaea/encoders/remoteclip_encoder.py b/pangaea/encoders/remoteclip_encoder.py index 9ff5a7ee..a904d1a0 100644 --- a/pangaea/encoders/remoteclip_encoder.py +++ b/pangaea/encoders/remoteclip_encoder.py @@ -351,6 +351,7 @@ def __init__( layers: int, mlp_ratio: float, output_layers: int | list[int], + output_dim: int | list[int], download_url: str, ls_init_value: float | None = None, patch_dropout: float = 0.0, @@ -365,9 +366,11 @@ def __init__( input_bands=input_bands, input_size=input_size, embed_dim=embed_dim, - output_dim=embed_dim, + output_layers=output_layers, + output_dim=output_dim, multi_temporal=False, multi_temporal_output=False, + pyramid_output=False, 
download_url=download_url, ) diff --git a/pangaea/encoders/satlasnet_encoder.py b/pangaea/encoders/satlasnet_encoder.py index 9012cf2b..4b20a0fc 100644 --- a/pangaea/encoders/satlasnet_encoder.py +++ b/pangaea/encoders/satlasnet_encoder.py @@ -418,7 +418,8 @@ def __init__( self, input_bands: dict[str, list[str]], input_size: int, - output_dim: int, + output_dim: int | list[int], + output_layers: int | list[int], model_identifier: str, encoder_weights: str | Path, download_url: str, @@ -433,9 +434,11 @@ def __init__( input_bands=input_bands, input_size=input_size, embed_dim=768, # will be overwritten by the backbone + output_layers=output_layers, output_dim=output_dim, multi_temporal=False if "_SI_" in model_identifier else True, multi_temporal_output=False, + pyramid_output=True, download_url=download_url, ) @@ -450,7 +453,6 @@ def __init__( self.multi_image = model_info["multi_image"] self.backbone_arch = model_info["backbone"] - self.out_dim = self.output_dim self.backbone = self._initialize_backbone( self.in_chans, self.backbone_arch, self.multi_image @@ -542,9 +544,12 @@ def load_encoder_weights(self, logger: Logger) -> None: def forward(self, imgs): # Define forward pass + print(imgs["optical"].shape) + if not isinstance(self.backbone, AggregationBackbone): + print(xx) + imgs["optical"] = imgs["optical"].squeeze(2) - x = imgs["optical"].squeeze(2) - x = self.backbone(x) + x = self.backbone(imgs["optical"]) if self.fpn: x = self.fpn(x) @@ -552,8 +557,6 @@ def forward(self, imgs): output = [] for i in range(len(x)): - B, C, H, W = x[i].shape - out = x[i].repeat(1, self.out_dim // C, 1, 1) - output.append(out) + output.append(x[i]) return output diff --git a/pangaea/encoders/scalemae_encoder.py b/pangaea/encoders/scalemae_encoder.py index 742f0564..3dbe3b44 100644 --- a/pangaea/encoders/scalemae_encoder.py +++ b/pangaea/encoders/scalemae_encoder.py @@ -54,6 +54,7 @@ def __init__( input_size: int, input_bands: dict[str, list[str]], output_layers: int | list[int], + output_dim: int | list[int], download_url: str, embed_dim: int = 1024, patch_size: int = 16, @@ -71,9 +72,11 @@ def __init__( input_bands=input_bands, input_size=input_size, embed_dim=embed_dim, - output_dim=embed_dim, + output_layers=output_layers, + output_dim=output_dim, multi_temporal=False, multi_temporal_output=False, + pyramid_output=False, download_url=download_url, ) diff --git a/pangaea/encoders/spectralgpt_encoder.py b/pangaea/encoders/spectralgpt_encoder.py index 98a09a7c..6c9f5526 100644 --- a/pangaea/encoders/spectralgpt_encoder.py +++ b/pangaea/encoders/spectralgpt_encoder.py @@ -45,11 +45,11 @@ def __init__( input_size: int, input_bands: dict[str, list[str]], output_layers: int | list[int], + output_dim: int | list[int], download_url: str, in_chans: int = 3, t_patch_size: int = 3, patch_size: int = 16, - output_dim: int = 768, embed_dim: int = 768, depth: int = 12, num_heads: int = 12, @@ -66,9 +66,11 @@ def __init__( input_bands=input_bands, input_size=input_size, embed_dim=embed_dim, + output_layers=output_layers, output_dim=output_dim, multi_temporal=False, multi_temporal_output=False, + pyramid_output=False, download_url=download_url, ) @@ -198,21 +200,13 @@ def forward(self, image: dict[str, torch.Tensor]) -> list[torch.Tensor]: pos_embed = self.pos_embed[:, :, :] x = x + pos_embed - # reshape to [N, T, L, C] or [N, T*L, C] - requires_t_shape = ( - len(self.blocks) > 0 # support empty decoder - and hasattr(self.blocks[0].attn, "requires_t_shape") - and self.blocks[0].attn.requires_t_shape - ) - if 
requires_t_shape: - x = x.view([N, T, L, C]) - output = [] for i, blk in enumerate(self.blocks): x = blk(x) if i in self.output_layers: if self.cls_embed: x = x[:, 1:] + x = x.view(N, T, L, C).transpose(2, 3).flatten(1, 2) out = ( x.permute(0, 2, 1) .view( diff --git a/pangaea/encoders/ssl4eo_data2vec_encoder.py b/pangaea/encoders/ssl4eo_data2vec_encoder.py index 5de04436..28ea33cd 100644 --- a/pangaea/encoders/ssl4eo_data2vec_encoder.py +++ b/pangaea/encoders/ssl4eo_data2vec_encoder.py @@ -359,6 +359,7 @@ def __init__( input_size: int, input_bands: dict[str, list[str]], output_layers: int | list[int], + output_dim: int | list[int], download_url: str, in_chans: int = 3, patch_size: int = 16, @@ -386,9 +387,11 @@ def __init__( input_bands=input_bands, input_size=input_size, embed_dim=embed_dim, - output_dim=embed_dim, + output_dim=output_dim, + output_layers=output_layers, multi_temporal=False, multi_temporal_output=False, + pyramid_output=False, download_url=download_url, ) diff --git a/pangaea/encoders/ssl4eo_dino_encoder.py b/pangaea/encoders/ssl4eo_dino_encoder.py index acc16668..308d8bdc 100644 --- a/pangaea/encoders/ssl4eo_dino_encoder.py +++ b/pangaea/encoders/ssl4eo_dino_encoder.py @@ -243,6 +243,7 @@ def __init__( input_size: int, input_bands: dict[str, list[str]], output_layers: int | list[int], + output_dim: int | list[int], download_url: str, in_chans: int = 3, patch_size: int = 16, @@ -265,9 +266,11 @@ def __init__( input_bands=input_bands, input_size=input_size, embed_dim=embed_dim, - output_dim=embed_dim, + output_dim=output_dim, + output_layers=output_layers, multi_temporal=False, multi_temporal_output=False, + pyramid_output=False, download_url=download_url, ) diff --git a/pangaea/encoders/ssl4eo_mae_encoder.py b/pangaea/encoders/ssl4eo_mae_encoder.py index ae4584af..43d32f3f 100644 --- a/pangaea/encoders/ssl4eo_mae_encoder.py +++ b/pangaea/encoders/ssl4eo_mae_encoder.py @@ -46,6 +46,7 @@ def __init__( input_size: int, input_bands: dict[str, list[str]], output_layers: int | list[int], + output_dim: int | list[int], download_url: str, embed_dim: int = 1024, patch_size: int = 16, @@ -61,9 +62,11 @@ def __init__( input_bands=input_bands, input_size=input_size, embed_dim=embed_dim, - output_dim=embed_dim, + output_dim=output_dim, + output_layers=output_layers, multi_temporal=False, multi_temporal_output=False, + pyramid_output=False, download_url=download_url, ) diff --git a/pangaea/encoders/ssl4eo_moco_encoder.py b/pangaea/encoders/ssl4eo_moco_encoder.py index 9f42ef6c..7c74635e 100644 --- a/pangaea/encoders/ssl4eo_moco_encoder.py +++ b/pangaea/encoders/ssl4eo_moco_encoder.py @@ -42,6 +42,7 @@ def __init__( input_size: int, input_bands: dict[str, list[str]], output_layers: int | list[int], + output_dim: int | list[int], download_url: str, embed_dim: int = 1024, patch_size: int = 16, @@ -60,9 +61,11 @@ def __init__( input_bands=input_bands, input_size=input_size, embed_dim=embed_dim, - output_dim=embed_dim, + output_layers=output_layers, + output_dim=output_dim, multi_temporal=False, multi_temporal_output=False, + pyramid_output=False, download_url=download_url, ) diff --git a/pangaea/encoders/unet_encoder.py b/pangaea/encoders/unet_encoder.py index a7089e7e..7c10e52a 100644 --- a/pangaea/encoders/unet_encoder.py +++ b/pangaea/encoders/unet_encoder.py @@ -25,6 +25,7 @@ def __init__( input_bands: dict[str, list[str]], input_size: int, topology: Sequence[int], + output_dim: int | list[int], download_url: str, encoder_weights: str | None = None, ): @@ -35,8 +36,10 @@ def __init__( 
input_size=input_size, embed_dim=0, output_dim=0, + output_layers=output_layers, multi_temporal=False, # single time frame multi_temporal_output=False, + pyramid_output=True, download_url=download_url, ) diff --git a/pangaea/engine/data_preprocessor.py b/pangaea/engine/data_preprocessor.py index f8b5baa6..ec6b5ebf 100644 --- a/pangaea/engine/data_preprocessor.py +++ b/pangaea/engine/data_preprocessor.py @@ -1,324 +1,105 @@ import logging import math import random +import numbers import numpy as np import torch import torchvision.transforms as T +import torchvision.transforms.functional as TF from torch.utils.data import Dataset -from pangaea.datasets.base import GeoFMDataset -from pangaea.encoders.base import Encoder +from hydra.utils import instantiate +from typing import Callable, Dict, List, Optional, Sequence, Union, Tuple +import copy -class RichDataset(Dataset): - """Dataset wrapper to add preprocessing steps.""" - def __init__(self, dataset: GeoFMDataset, encoder: Encoder): - """Initialize the RichDataset. - - Args: - dataset (GeoFMDataset): dataset used. - encoder (Encoder): encoder used. - """ - self.dataset = dataset - self.encoder = encoder - - # WARNING: Patch to overcome recursive wrapping issues - self.split = dataset.split - self.dataset_name = dataset.dataset_name - self.multi_modal = dataset.multi_modal - self.multi_temporal = dataset.multi_temporal - self.root_path = dataset.root_path - self.classes = dataset.classes - self.num_classes = dataset.num_classes - self.ignore_index = dataset.ignore_index - self.img_size = dataset.img_size - self.bands = dataset.bands - self.distribution = dataset.distribution - self.data_mean = dataset.data_mean - self.data_std = dataset.data_std - self.data_min = dataset.data_min - self.data_max = dataset.data_max - self.download_url = dataset.download_url - self.auto_download = dataset.auto_download - - def __getitem__( - self, index: int - ) -> dict[str, torch.Tensor | dict[str, torch.Tensor]]: - """Return a modified item from the dataset. - - Args: - index (int): index of data. - Returns: - dict[str, torch.Tensor | dict[str, torch.Tensor]]: output dictionary following the format - {"image": - { - "optical": torch.Tensor of shape (C H W) (or (C T H W) if multi-temporal dataset), - "sar": torch.Tensor of shape (C H W) (or (C T H W) if multi-temporal dataset) - }, - "target": torch.Tensor of shape (H W), - "metadata": dict}. - """ - return self.dataset[index] - - def __len__(self) -> int: - """Return the length of the dataset. - - Returns: - int: length of the dataset. - """ - return len(self.dataset) +class BasePreprocessor(): + """Base class for preprocessor.""" + def __init__(self,) -> None: + return -class SegPreprocessor(RichDataset): - def __init__(self, dataset: GeoFMDataset, encoder: Encoder) -> None: - """Initialize the SegPreprocessor for segmentation tasks. - - Args: - dataset (GeoFMDataset): dataset used. - encoder (Encoder): encoder used. 
- """ - super().__init__(dataset, encoder) - - self.preprocessor = {} - self.preprocessor["optical"] = ( - BandAdaptor(dataset=dataset, encoder=encoder, modality="optical") - if "optical" in dataset.bands.keys() - else None - ) - self.preprocessor["sar"] = ( - BandAdaptor(dataset=dataset, encoder=encoder, modality="sar") - if "sar" in dataset.bands.keys() - else None - ) - for modality in self.encoder.input_bands: - new_stats = self.preprocessor[modality].preprocess_band_statistics( - self.dataset.data_mean[modality], - self.dataset.data_std[modality], - self.dataset.data_min[modality], - self.dataset.data_max[modality], - ) - - self.dataset.data_mean[modality] = new_stats[0] - self.dataset.data_std[modality] = new_stats[1] - self.dataset.data_min[modality] = new_stats[2] - self.dataset.data_max[modality] = new_stats[3] - - def __getitem__( - self, index: int + def __call__( + self, data: dict[str, torch.Tensor | dict[str, torch.Tensor]] ) -> dict[str, torch.Tensor | dict[str, torch.Tensor]]: - """Return a modified item from the dataset. - - Args: - index (int): index of data. - Returns: - dict[str, torch.Tensor | dict[str, torch.Tensor]]: output dictionary following the format - {"image": - { - encoder_modality_1: torch.Tensor of shape (C T H W) (with T=1 if single timeframe) - ... - encoder_modality_N: torch.Tensor of shape (C T H W) (with T=1 if single timeframe) - }, - "target": torch.Tensor of shape (H W), - "metadata": dict}. - """ - data = self.dataset[index] - for k, v in data["image"].items(): - if k in self.encoder.input_bands: - data["image"][k] = self.preprocessor[k](v) - - data["target"] = data["target"].long() - return data + raise NotImplementedError + def update_meta(self, meta): + raise NotImplementedError -class RegPreprocessor(SegPreprocessor): - def __init__(self, dataset: GeoFMDataset, encoder: Encoder) -> None: - """Initialize the RegPreprocessor for regression tasks.""" - super().__init__(dataset, encoder) - - def __getitem__( - self, index: int - ) -> dict[str, torch.Tensor | dict[str, torch.Tensor]]: - """Return a modified item from the dataset. - - Args: - index (int): index of data. - Returns: - dict[str, torch.Tensor | dict[str, torch.Tensor]]: output dictionary following the format - {"image": - { - encoder_modality_1: torch.Tensor of shape (C T H W) (with T=1 if single timeframe) - ... - encoder_modality_N: torch.Tensor of shape (C T H W) (with T=1 if single timeframe) - }, - "target": torch.Tensor of shape (H W), - "metadata": dict}. - """ - data = self.dataset[index] + def check_dimension(self, data: dict[str, torch.Tensor | dict[str, torch.Tensor]]): + """check dimension (C, T, H, W) of data""" for k, v in data["image"].items(): - if k in self.encoder.input_bands: - data["image"][k] = self.preprocessor[k](v) - data["target"] = data["target"].float() - return data + if len(v.shape) != 4: + raise AssertionError(f"Image dimension must be 4 (C, T, H, W), Got {str(len(v.shape))}") + if len(data["target"].shape) != 2: + raise AssertionError(f"Target dimension must be 2 (H, W), Got {str(len(data['target'].shape))}") -class BandAdaptor: - def __init__(self, dataset: GeoFMDataset, encoder: Encoder, modality: str) -> None: - """Intialize the BandAdaptor. - Args: - dataset (GeoFMDataset): dataset used. - encoder (Encoder): encoder unded. - modality (str): image modality. 
- """ - self.dataset_bands = dataset.bands[modality] - self.input_bands = getattr(encoder.input_bands, modality, []) - - # list of length dataset_n_bands with True if the band is used in the encoder - # and is available in the dataset - self.used_bands_mask = torch.tensor( - [b in self.input_bands for b in self.dataset_bands], dtype=torch.bool - ) - # list of length encoder_n_bands with True if the band is available in the dataset - # and used in the encoder - self.avail_bands_mask = torch.tensor( - [b in self.dataset_bands for b in self.input_bands], dtype=torch.bool - ) - # list of length encoder_n_bands with the index of the band in the dataset - # if the band is available in the dataset and -1 otherwise - self.avail_bands_indices = torch.tensor( - [ - self.dataset_bands.index(b) if b in self.dataset_bands else -1 - for b in self.input_bands - ], - dtype=torch.long, - ) - - # if the encoder requires bands that are not available in the dataset - # then we need to pad the input with zeros - self.need_padded = self.avail_bands_mask.sum() < len(self.input_bands) - self.logger = logging.getLogger() - self.logger.info(f"Adaptor for modality: {modality}") - self.logger.info( - "Available bands in dataset: {}".format( - " ".join(str(b) for b in self.dataset_bands) - ) - ) - self.logger.info( - "Required bands in encoder: {}".format( - " ".join(str(b) for b in self.input_bands) - ) - ) - if self.need_padded: - self.logger.info( - "Unavailable bands {} are padded with zeros".format( - " ".join( - str(b) - for b in np.array(self.input_bands)[ - self.avail_bands_mask.logical_not() - ] - ) - ) - ) + def check_size(self, data: dict[str, torch.Tensor | dict[str, torch.Tensor]]): + """check if data size is equal""" + base_shape = data["image"][list(data["image"].keys())[0]].shape - def preprocess_band_statistics( - self, - data_mean: list[float], - data_std: list[float], - data_min: list[float], - data_max: list[float], - ) -> tuple[ - list[float], - list[float], - list[float], - list[float], - ]: - """Filter the statistics to match the available bands. - Args: - data_mean (list[float]): dataset mean (per band in dataset). - data_std (list[float]): dataset std (per band in dataset). - data_min (list[float]): dataset min (per band in dataset). - data_max (list[float]): dataset max (per band in dataset). - Returns: - tuple[ list[float], list[float], list[float], list[float], ]: - dataset mean, std, min, max (per band in encoder). Pad with zeros - if the band is required by the encoder but not included in the dataset. - """ - data_mean = [ - data_mean[i] if i != -1 else 0.0 for i in self.avail_bands_indices.tolist() - ] - data_std = [ - data_std[i] if i != -1 else 1.0 for i in self.avail_bands_indices.tolist() - ] - data_min = [ - data_min[i] if i != -1 else -1.0 for i in self.avail_bands_indices.tolist() - ] - data_max = [ - data_max[i] if i != -1 else 1.0 for i in self.avail_bands_indices.tolist() - ] - return data_mean, data_std, data_min, data_max - - def preprocess_single_timeframe(self, image: torch.Tensor) -> torch.Tensor: - """Apply the preprocessing to a single timeframe, i.e. pad unavailable - bands with zeros if needed to match encoder's bands. - Args: - image (torch.Tensor): input image of shape (dataset_n_bands H W). - Returns: - torch.Tensor: output image of shape (encoder_n_bands H W). - """ - # add padding band at index 0 on the first dim - padded_image = torch.cat([torch.zeros_like(image[0:1]), image], dim=0) - # request all encoder's band. 
In self.avail_band_indices we have - # -1 for bands not available in the dataset. So we add 1 to get the - # correct index in the padded image (index 0 is the 0-padding band) - return padded_image[self.avail_bands_indices + 1] - - def __call__(self, image: torch.Tensor) -> torch.Tensor: - """Apply the preprocessing to the image. Pad unavailable bands with zeros. - Args: - image (torch.Tensor): image of shape (dataset_n_bands H W). - Returns: - torch.Tensor: output image of shape (encoder_n_bands T H W). - In the case of sigle timeframe, T = 1. - """ - # input of shape (dataset_n_bands T H W) output of shape (encoder_n_bands T H W) - if len(image.shape) == 3: - # Add a time dimension so preprocessing can work on consistent images - image = image.unsqueeze( - 1 - ) # (dataset_n_bands H W)-> (dataset_n_bands 1 H W) - - if image.shape[1] != 1: - final_image = [] - for i in range(image.shape[1]): - final_image.append(self.preprocess_single_timeframe(image[:, i, :, :])) - image = torch.stack(final_image, dim=1) - else: - image = self.preprocess_single_timeframe(image) - - # OUTPUT SHAPE (encoder_n_bands T H W) (T = 1 in the case of single timeframe) - return image + for k, v in data["image"].items(): + if v.shape[1:] != base_shape[1:]: + shape = {k: tuple(v.shape[1:]) for k, v in data["image"].items()} + raise AssertionError(f"Image size (T, H, W) from all modalities must be equal, Got {str(shape)}") + if base_shape[-2:] != data["target"].shape[-2:]: + raise AssertionError(f"Image size and target size (H, W) must be equal, Got {str(tuple(base_shape[-2:]))} and {str(tuple(data['target'].shape[-2:]))}") -class BaseAugment(RichDataset): - """Base class for augmentations.""" - def __init__(self, dataset: GeoFMDataset, encoder: Encoder) -> None: - """Augment item. Should call the dataset __getitem__ method which output a dictionary - with the keys "image", "target" and "metadata": - dict[str, torch.Tensor | dict[str, torch.Tensor]]: output dictionary following the format - {"image": - { - encoder_modality_1: torch.Tensor of shape (C T H W) (T=1 if single timeframe), - ... - encoder_modality_N: torch.Tensor of shape (C T H W) (T=1 if single timeframe), - }, - "target": torch.Tensor of shape (H W), - "metadata": dict}. +class Preprocessor(BasePreprocessor): + """A series of base preprocessors that preprocess images and targets.""" + def __init__( + self, + preprocessor_cfg, + dataset_cfg, + encoder_cfg + ) -> None: + """Build preprocessors defined in preprocessor_cfg. 
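+ Each entry in preprocessor_cfg is instantiated with the running meta dict (bands, image size, mean/std/min/max, ignore index, class distribution); its update_meta() is then called so that the next preprocessor in the chain sees statistics that already reflect the previous step.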
+ Args: + preprocessor_cfg: preprocessor config + dataset_cfg: dataset config + encoder_cfg: encoder config + """ + super().__init__() + # initialize the meta statistics/info of the input data and target encoder + meta = {} + meta['dataset_img_size'] = dataset_cfg['img_size'] + meta['encoder_input_size'] = encoder_cfg['input_size'] + meta['dataset_bands'] = dataset_cfg['bands'] + meta['encoder_bands'] = encoder_cfg['input_bands'] + meta['multi_modal'] = dataset_cfg['multi_modal'] + meta['multi_temporal'] = dataset_cfg['multi_temporal'] + + meta['data_bands'] = dataset_cfg['bands'] + meta['data_img_size'] = dataset_cfg['img_size'] + meta['data_mean'] = {k: torch.tensor(v) for k, v in dataset_cfg['data_mean'].items()} + meta['data_std'] = {k: torch.tensor(v) for k, v in dataset_cfg['data_std'].items()} + meta['data_min'] = {k: torch.tensor(v) for k, v in dataset_cfg['data_min'].items()} + meta['data_max'] = {k: torch.tensor(v) for k, v in dataset_cfg['data_max'].items()} + + meta['ignore_index'] = dataset_cfg['ignore_index'] + meta['class_distribution'] = torch.tensor(dataset_cfg['distribution']) + + self.preprocessor = [] + + # build the preprocessor and update the meta for the next + for preprocess in preprocessor_cfg: + preprocessor = instantiate(preprocess, **meta) + meta = preprocessor.update_meta(meta) + self.preprocessor.append(preprocessor) + + def __call__( + self, data: dict[str, torch.Tensor | dict[str, torch.Tensor]] + ) -> dict[str, torch.Tensor | dict[str, torch.Tensor]]: + """preprocess images and targets step by step. Args: - index (int): index of data. - + data (dict): input data. Returns: dict[str, torch.Tensor | dict[str, torch.Tensor]]: output dictionary following the format {"image": @@ -330,57 +111,45 @@ def __init__(self, dataset: GeoFMDataset, encoder: Encoder) -> None: "target": torch.Tensor of shape (H W), "metadata": dict}. """ - super().__init__(dataset, encoder) + self.check_dimension(data) + for process in self.preprocessor: + data = process(data) + return data + + +class BandFilter(BasePreprocessor): -class Tile(BaseAugment): def __init__( - self, dataset: GeoFMDataset, encoder: Encoder, min_overlap: int = 0 + self, + **meta ) -> None: - """Initialize the Tiling augmentation. - + """Intialize the BandFilter. Args: - dataset (GeoFMDataset): dataset used. - encoder (Encoder): encoder used. - min_overlap (int, optional): minimum overlap between tiles. Defaults to 0. + meta: statistics/info of the input data and target encoder + data_bands: bands of incoming data + encoder_bands: expected bands by encoder """ - super().__init__(dataset, encoder) - self.min_overlap = min_overlap - # Should be the _largest_ image in the dataset to avoid problems mentioned in __getitem__ - self.input_size = self.dataset.img_size - self.output_size = self.encoder.input_size - if self.output_size == self.input_size: - self.tiles_per_dim = 1 - elif self.output_size > self.input_size: - raise ValueError( - f"Can't tile inputs if dataset.img_size={self.input_size} < encoder.input_size={self.output_size}, use ResizeToEncoder instead." 
- ) - elif self.min_overlap >= self.input_size: - raise ValueError("min_overlap >= dataset.img_size") - elif self.min_overlap >= self.input_size: - raise ValueError("min_overlap >= encoder.input_size") - else: - self.tiles_per_dim = math.ceil( - (self.input_size - self.min_overlap) - / (self.output_size - self.min_overlap) - ) + super().__init__() - logging.getLogger().info( - f"Tiling {self.input_size}x{self.input_size} input images to {self.tiles_per_dim * self.tiles_per_dim} {self.output_size}x{self.output_size} output images." - ) + self.used_bands_indices = {} - self.h_spacing_cache = [None] * super().__len__() - self.w_spacing_cache = [None] * super().__len__() + for k in meta['data_bands'].keys(): + if k not in meta['encoder_bands'].keys(): + continue + self.used_bands_indices[k] = torch.tensor( + [meta['data_bands'][k].index(b) for b in meta['encoder_bands'][k] if b in meta['data_bands'][k]], dtype=torch.long + ) - self.data_cache = (None, None) + if not self.used_bands_indices: + raise ValueError("No nontrivial input bands after BandFilter!") - def __getitem__( - self, index: int + def __call__( + self, data: dict[str, torch.Tensor | dict[str, torch.Tensor]] ) -> dict[str, torch.Tensor | dict[str, torch.Tensor]]: - """Apply Tiling to the data. + """Filter redundant bands from the data. Args: - index (int): index of data. - + data (dict): input data. Returns: dict[str, torch.Tensor | dict[str, torch.Tensor]]: output dictionary following the format {"image": @@ -392,155 +161,119 @@ def __getitem__( "target": torch.Tensor of shape (H W), "metadata": dict}. """ - if self.tiles_per_dim == 1: - return self.dataset[index] - - dataset_index = math.floor(index / (self.tiles_per_dim * self.tiles_per_dim)) - data = self.dataset[dataset_index] - # Calculate tile coordinates - tile_index = index % (self.tiles_per_dim * self.tiles_per_dim) - h_index = math.floor(tile_index / self.tiles_per_dim) - w_index = tile_index % self.tiles_per_dim - # Use the actual image size so we can handle data that's not always uniform. - # This means that min_overlap might not always be respected. - # Also, in case there was insufficient overlap (or tiles_per_dim=1) sepcified, we'll crop the image and lose info. 
- input_h, input_w = data["image"][next(iter(data["image"].keys()))].shape[-2:] - - # Calculate the sizes of the labeled parts seperately to deal with aliasing when - # tile spacing values are not exact integers - if not self.h_spacing_cache[dataset_index]: - float_spacing = np.linspace( - 0, input_h - self.output_size, self.tiles_per_dim - ) - rounded_spacing = float_spacing.round().astype(int) - labeled_sizes = np.ediff1d(rounded_spacing, to_end=self.output_size) - self.h_spacing_cache[dataset_index] = (rounded_spacing, labeled_sizes) - if not self.w_spacing_cache[dataset_index]: - float_spacing = np.linspace( - 0, input_w - self.output_size, self.tiles_per_dim - ) - rounded_spacing = float_spacing.round().astype(int) - labeled_sizes = np.ediff1d(rounded_spacing, to_end=self.output_size) - self.w_spacing_cache[dataset_index] = (rounded_spacing, labeled_sizes) - h_positions, h_labeled_sizes = self.h_spacing_cache[dataset_index] - w_positions, w_labeled_sizes = self.w_spacing_cache[dataset_index] + data["image"] = {k: data["image"][k][v] for k, v in self.used_bands_indices.items()} - h, w = h_positions[h_index], w_positions[w_index] - h_labeled, w_labeled = h_labeled_sizes[h_index], w_labeled_sizes[w_index] + return data - tiled_data = {"image": {}, "target": None} - tiled_data["image"] = {} - for k, v in data["image"].items(): - if k in self.encoder.input_bands: - tiled_data["image"][k] = v[ - ..., h : h + self.output_size, w : w + self.output_size - ].clone() - - # Place the mesaured part in the middle to help with tiling artefacts - h_label_offset = round((self.output_size - h_labeled) / 2) - w_label_offset = round((self.output_size - w_labeled) / 2) - - # Crop target to size - tiled_data["target"] = data["target"][ - ..., h : h + self.output_size, w : w + self.output_size - ].clone() - - # Ignore overlapping borders - if h_index != 0: - tiled_data["target"][..., 0:h_label_offset, :] = self.dataset.ignore_index - if w_index != 0: - tiled_data["target"][..., 0:w_label_offset] = self.dataset.ignore_index - if h_index != self.tiles_per_dim - 1: - tiled_data["target"][..., self.output_size - h_label_offset :, :] = ( - self.dataset.ignore_index - ) - if w_index != self.tiles_per_dim - 1: - tiled_data["target"][..., self.output_size - w_label_offset :] = ( - self.dataset.ignore_index - ) + def update_meta(self, meta): + """Tracking the meta statistics/info for next processor.""" + for k in list(meta['data_bands'].keys()): + if k not in self.used_bands_indices.keys(): + meta['data_bands'].pop(k, None) + meta['data_mean'].pop(k, None) + meta['data_std'].pop(k, None) + meta['data_min'].pop(k, None) + meta['data_max'].pop(k, None) + else: + meta['data_bands'][k] = [meta['data_bands'][k][i.item()] for i in self.used_bands_indices[k]] + meta['data_mean'][k] = meta['data_mean'][k][self.used_bands_indices[k]] + meta['data_std'][k] = meta['data_std'][k][self.used_bands_indices[k]] + meta['data_min'][k] = meta['data_min'][k][self.used_bands_indices[k]] + meta['data_max'][k] = meta['data_max'][k][self.used_bands_indices[k]] - return tiled_data + return meta - def __len__(self): - return (super().__len__()) * (self.tiles_per_dim * self.tiles_per_dim) +class BandPadding(BasePreprocessor): -class RandomFlip(BaseAugment): def __init__( - self, - dataset: GeoFMDataset, - encoder: Encoder, - ud_probability: float, - lr_probability: float, + self, + fill_value: float = 0.0, + **meta ) -> None: - """Initialize the RandomFlip. + """Intialize the BandPadding. Args: - dataset (GeoFMDataset): dataset used. 
- encoder (Encoder): encoder used. - ud_probability (float): Up/Down augmentation probability. - lr_probability (float): Left/Right augmentation probability. + fill_value (float): fill value for padding. + meta: statistics/info of the input data and target encoder + data_bands: bands of incoming data + encoder_bands: expected bands by encoder """ - super().__init__(dataset, encoder) - self.ud_probability = ud_probability - self.lr_probability = lr_probability + super().__init__() - def __getitem__( - self, index: int + self.fill_value = fill_value + self.data_img_size = meta['data_img_size'] + + self.encoder_bands = meta['encoder_bands'] + self.avail_bands_mask, self.used_bands_indices = {}, {} + for k in meta['encoder_bands'].keys(): + if k in meta['data_bands'].keys(): + self.avail_bands_mask[k] = torch.tensor( + [b in meta['data_bands'][k] for b in meta['encoder_bands'][k]], dtype=torch.bool + ) + self.used_bands_indices[k] = torch.tensor( + [meta['data_bands'][k].index(b) for b in meta['encoder_bands'][k] if b in meta['data_bands'][k]], dtype=torch.long + ) + else: + self.avail_bands_mask[k] = torch.zeros(len(meta['encoder_bands'][k]), dtype=torch.bool) + if not self.used_bands_indices: + raise ValueError("No nontrivial input bands after BandPadding!") + + def __call__( + self, data: dict[str, torch.Tensor | dict[str, torch.Tensor]] ) -> dict[str, torch.Tensor | dict[str, torch.Tensor]]: - """Apply Random FLIP to the data. - Args: - index (int): index of data. - Returns: - dict[str, torch.Tensor | dict[str, torch.Tensor]]: output dictionary following the format - {"image": - { - encoder_modality_1: torch.Tensor of shape (C T H W) (T=1 if single timeframe), - ... - encoder_modality_N: torch.Tensor of shape (C T H W) (T=1 if single timeframe), - }, - "target": torch.Tensor of shape (H W), - "metadata": dict}. 
- """ - data = self.dataset[index] - if random.random() < self.lr_probability: - for k, v in data["image"].items(): - if k in self.encoder.input_bands: - data["image"][k] = torch.fliplr(v) - data["target"] = torch.fliplr(data["target"]) - if random.random() < self.ud_probability: - for k, v in data["image"].items(): - if k in self.encoder.input_bands: - data["image"][k] = torch.flipud(v) - data["target"] = torch.flipud(data["target"]) + for k in self.avail_bands_mask.keys(): + if k in self.used_bands_indices.keys(): + size = self.avail_bands_mask[k].shape + data["image"][k].shape[1:] + padded_image = torch.full(size, fill_value=self.fill_value, dtype=data["image"][k].dtype) + padded_image[self.avail_bands_mask[k]] = data["image"][k][self.used_bands_indices[k]] + else: + reference = data["image"](list(data["image"].keys())[0]) + size = self.avail_bands_mask[k].shape + reference.shape[1:] + padded_image = torch.full(size, fill_value=self.fill_value, dtype=reference.dtype) + + data["image"][k] = padded_image return data - -class GammaAugment(BaseAugment): + def update_meta(self, meta): + """Tracking the meta statistics/info for next processor.""" + meta['data_bands'] = meta['encoder_bands'] + for k in self.avail_bands_mask.keys(): + size = self.avail_bands_mask[k].shape + meta['data_mean'][k] = torch.full(size, fill_value=self.fill_value, dtype=torch.float) + meta['data_std'][k] = torch.ones(size, dtype=torch.float) + meta['data_min'][k] = torch.full(size, fill_value=self.fill_value, dtype=torch.float) + meta['data_max'][k] = torch.full(size, fill_value=self.fill_value, dtype=torch.float) + if self.used_bands_indices[k] is not None: + meta['data_mean'][k][self.avail_bands_mask[k]] = meta['data_mean'][k][self.used_bands_indices[k]] + meta['data_std'][k][self.avail_bands_mask[k]] = meta['data_std'][k][self.used_bands_indices[k]] + meta['data_min'][k][self.avail_bands_mask[k]] = meta['data_min'][k][self.used_bands_indices[k]] + meta['data_max'][k][self.avail_bands_mask[k]] = meta['data_max'][k][self.used_bands_indices[k]] + return meta + + +class NormalizeMeanStd(BasePreprocessor): def __init__( - self, - dataset: GeoFMDataset, - encoder: Encoder, - probability: float, - gamma_range: float, + self, + **meta, ) -> None: - """Initialize the GammaAugment. + """Initialize the NormalizeMeanStd. Args: - dataset (GeoFMDataset): dataset used. - encoder (Encoder): encoder used. - probability (float): probability of applying the augmentation. - gamma_range (float): gamma range. + meta: statistics/info of the input data and target encoder + data_mean: global mean of incoming data + data_std: global std of incoming data """ - super().__init__(dataset, encoder) - self.probability = probability - self.gamma_range = gamma_range + super().__init__() + + self.data_mean = meta['data_mean'] + self.data_std = meta['data_std'] - def __getitem__( - self, index: int + def __call__( + self, data: dict[str, torch.Tensor | dict[str, torch.Tensor]] ) -> dict[str, torch.Tensor | dict[str, torch.Tensor]]: - """Apply Gamma Augment to the data. + """Apply Mean/Std Normalization to the data. Args: - index (int): index of data. + data (dict): input data. Returns: dict[str, torch.Tensor | dict[str, torch.Tensor]]: output dictionary following the format {"image": @@ -553,43 +286,42 @@ def __getitem__( "metadata": dict}. 
""" - data = self.dataset[index] - # WARNING: Test this bit of code - if random.random() < self.probability: - for k, v in data["image"].items() and k in self.encoder.input_bands: - data["image"][k] = torch.pow(v, random.uniform(*self.gamma_range)) + for k in self.data_mean.keys(): + data["image"][k].sub_(self.data_mean[k].view(-1, 1, 1, 1)).div_(self.data_std[k].view(-1, 1, 1, 1)) return data + def update_meta(self, meta): + """Tracking the meta statistics/info for next processor.""" + meta['data_mean'] = {k: torch.zeros_like(v) for k, v in meta['data_mean'].items()} + meta['data_std'] = {k: torch.ones_like(v) for k, v in meta['data_std'].items()} + meta['data_min'] = {k: (v - meta['data_mean'][k]) / meta['data_std'][k] for k, v in meta['data_min'].items()} + meta['data_max'] = {k: (v - meta['data_mean'][k]) / meta['data_std'][k] for k, v in meta['data_max'].items()} + + return meta -class NormalizeMeanStd(BaseAugment): + +class NormalizeMinMax(BasePreprocessor): def __init__( - self, - dataset: GeoFMDataset, - encoder: Encoder + self, + **meta, ) -> None: """Initialize the NormalizeMeanStd. Args: - dataset (GeoFMDataset): dataset used. - encoder (Encoder): encoder used. + meta: statistics/info of the input data and target encoder + data_min: global maximum value of incoming data + data_sax: global minimum value of incoming data """ - super().__init__(dataset, encoder) - self.data_mean_tensors = {} - self.data_std_tensors = {} - # Bands is a dict of {modality:[b1, b2, ...], ...} so it's keys are the modalaities in use - for modality in self.encoder.input_bands: - self.data_mean_tensors[modality] = torch.tensor( - self.dataset.data_mean[modality] - ).reshape((-1, 1, 1, 1)) - self.data_std_tensors[modality] = torch.tensor( - self.dataset.data_std[modality] - ).reshape((-1, 1, 1, 1)) - - def __getitem__( - self, index: int + super().__init__() + + self.data_min = meta['data_min'] + self.data_max = meta['data_max'] + + def __call__( + self, data: dict[str, torch.Tensor | dict[str, torch.Tensor]] ) -> dict[str, torch.Tensor | dict[str, torch.Tensor]]: """Apply Mean/Std Normalization to the data. Args: - index (int): index of data. + data (dict): input data. Returns: dict[str, torch.Tensor | dict[str, torch.Tensor]]: output dictionary following the format {"image": @@ -601,49 +333,87 @@ def __getitem__( "target": torch.Tensor of shape (H W), "metadata": dict}. 
""" - data = self.dataset[index] - for modality in self.encoder.input_bands: - data["image"][modality] = ( - data["image"][modality] - self.data_mean_tensors[modality] - ) / self.data_std_tensors[modality] + + for k in self.data_min.keys(): + size = (-1,) + data["image"][k].shape[1:] + data["image"][k].sub_(self.data_min[k].view(size)).div_((self.data_max[k]-self.data_min[k]).view(size)) return data + def update_meta(self, meta): + """Tracking the meta statistics/info for next processor.""" + meta['data_mean'] = {k: (v - meta['data_min'][k]) / (meta['data_max'][k] - meta['data_min'][k]) for k, v in meta['data_mean'].items()} + meta['data_std'] = {k: v / (meta['data_max'][k] - meta['data_min'][k]) for k, v in meta['data_std'].items()} + meta['data_min'] = {k: torch.zeros_like(v) for k, v in meta['data_mean'].items()} + meta['data_max'] = {k: torch.ones_like(v) for k, v in meta['data_std'].items()} -class NormalizeMinMax(BaseAugment): + return meta + + +class RandomCrop(BasePreprocessor): def __init__( self, - dataset: GeoFMDataset, - encoder: Encoder, - data_min: torch.Tensor, - data_max: torch.Tensor, + size: int | Sequence[int], + pad_if_needed: bool = False, + **meta ) -> None: - """Apply Min/Max Normalization to scale the data to the range [data_min, data_max]. + """Initialize the RandomCrop preprocessor. + Args: + size (int): crop size. + pad_if_needed (bool, optional): whether to pad. Defaults to False. + meta: statistics/info of the input data and target encoder + data_mean: global mean value of incoming data for potential padding + ignore_index: ignore index for potential padding + """ + super().__init__() + + self.size = tuple(_setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")) + self.pad_if_needed = pad_if_needed + self.pad_value = meta['data_mean'] + self.ignore_index = meta['ignore_index'] + + def get_params(self, data: dict) -> Tuple[int, int, int, int]: + """Get parameters for ``crop`` for a random crop. + Args: - dataset (GeoFMDataset): dataset used. - encoder (Encoder): encoder used. - data_min (torch.Tensor): data_min scalar tensor. - data_max (torch.Tensor): data_max scalar tensor. + data (dict): input data. + + Returns: + tuple: params (i, j, h, w) to be passed to ``crop`` for random crop. 
""" - super().__init__(dataset, encoder) - self.normalizers = {} - self.data_min_tensors = {} - self.data_max_tensors = {} - self.min = data_min - self.max = data_max - for modality in self.encoder.input_bands: - self.data_min_tensors[modality] = torch.tensor( - self.dataset.data_min[modality] - ).reshape((-1, 1, 1, 1)) - self.data_max_tensors[modality] = torch.tensor( - self.dataset.data_max[modality] - ).reshape((-1, 1, 1, 1)) - - def __getitem__( - self, index: int + h, w = data["image"][list(data["image"].keys())[0]].shape[-2:] + th, tw = self.size + if h < th or w < tw: + raise ValueError(f"Required crop size {(th, tw)} is larger than input image size {(h, w)}") + + if w == tw and h == th: + return 0, 0, h, w + + i = torch.randint(0, h - th + 1, size=(1,)).item() + j = torch.randint(0, w - tw + 1, size=(1,)).item() + return i, j, th, tw + + def check_pad(self, data: dict[str, torch.Tensor | dict[str, torch.Tensor]], + ) -> dict[str, torch.Tensor | dict[str, torch.Tensor]]: + _, t, height, width = data["image"][list(data["image"].keys())[0]].shape + + if height < self.size[0] or width < self.size[1]: + pad_img = max(self.size[0] - height, 0), max(self.size[1] - width, 0) + height, width = height + 2 * pad_img[0], width + 2 * pad_img[1] + for k, v in data["image"].items(): + padded_img = self.pad_value[k].reshape(-1, 1, 1, 1).repeat(1, t, height, width) + padded_img[:, :, pad_img[0]:-pad_img[0], pad_img[1]:-pad_img[1]] = v + data["image"][k] = padded_img + + data["target"] = TF.pad(data["target"], padding=padded_img, fill=self.ignore_index, padding_mode='constant') + + return data + + def __call__( + self, data: dict[str, torch.Tensor | dict[str, torch.Tensor]] ) -> dict[str, torch.Tensor | dict[str, torch.Tensor]]: - """Apply Min/Max Normalization to the data. + """Random crop the data. Args: - index (int): index of data. + data (dict): input data. Returns: dict[str, torch.Tensor | dict[str, torch.Tensor]]: output dictionary following the format {"image": @@ -655,149 +425,142 @@ def __getitem__( "target": torch.Tensor of shape (H W), "metadata": dict}. """ + self.check_size(data) + + if self.pad_if_needed: + data = self.check_pad(data) + + i, j, h, w = self.get_params(data=data) + + for k, v in data["image"].items(): + data["image"][k] = TF.crop(v, i, j, h, w) + + data["target"] = TF.crop(data["target"], i, j, h, w) - data = self.dataset[index] - for modality in self.encoder.input_bands: - data["image"][modality] = ( - (data["image"][modality] - self.data_min_tensors[modality]) - * (self.max - self.min) - - self.min - ) / self.data_max_tensors[modality] - return data - + def update_meta(self, meta): + """Tracking the meta statistics/info for next processor.""" + meta['data_img_size'] = self.size[0] + return meta -class ColorAugmentation(BaseAugment): + +class RandomCropToEncoder(RandomCrop): def __init__( self, - dataset: GeoFMDataset, - encoder: Encoder, - brightness: float = 0, - contrast: float = 0, - clip: bool = False, - br_probability: float = 0, - ct_probability: float = 0, + pad_if_needed: bool = False, + **meta ) -> None: - """Initialize the ColorAugmentation. + """Initialize the RandomCropToEncoder preprocessor. Args: - dataset (GeoFMDataset): dataset used. - encoder (Encoder): encoder used. - brightness (float, optional): brightness parameter. Defaults to 0. - contrast (float, optional): contrast. Defaults to 0. - clip (bool, optional): clip parameter. Defaults to False. - br_probability (float, optional): brightness augmentation probability. - Defaults to 0. 
- ct_probability (float, optional): contrast augmentation probability. - Defaults to 0. + size (int): crop size. + pad_if_needed (bool, optional): whether to pad. Defaults to False. + meta: statistics/info of the input data and target encoder + data_mean: global mean value of incoming data for potential padding + ignore_index: ignore index for potential padding """ - super().__init__(dataset, encoder) - self.brightness = brightness - self.contrast = contrast - self.clip = clip - self.br_probability = br_probability - self.ct_probability = ct_probability - - def adjust_brightness( - self, image: torch.Tensor, factor: float | torch.Tensor, clip_output: bool - ) -> torch.Tensor: - """Adjust the brightness of the image. + size = meta['encoder_input_size'] + super().__init__( + size, pad_if_needed, **meta + ) + +class FocusRandomCrop(RandomCrop): + def __init__( + self, + size: int, + pad_if_needed: bool = False, + **meta + ) -> None: + """Initialize the FocusRandomCrop preprocessor. Args: - image (torch.Tensor): input image of shape (C T H W) (T=1 if single timeframe). - factor (float | torch.Tensor): adjustment factor. - clip_output (bool): whether to clip the output. - Returns: - torch.Tensor: output image of shape (C T H W) (T=1 if single timeframe). + size (int): crop size. + pad_if_needed (bool, optional): whether to pad. Defaults to False. + meta: statistics/info of the input data and target encoder + data_mean: global mean value of incoming data for potential padding + ignore_index: ignore index for potential padding """ - if isinstance(factor, float): - factor = torch.as_tensor(factor, device=image.device, dtype=image.dtype) - while len(factor.shape) != len(image.shape): - factor = factor[..., None] - - img_adjust = image + factor - if clip_output: - img_adjust = img_adjust.clamp(min=-1.0, max=1.0) + super().__init__( + size, + pad_if_needed, + **meta) - return img_adjust + def get_params(self, data: dict) -> Tuple[int, int, int, int]: + """Get parameters for ``crop`` for a random crop. - def adjust_contrast( - self, image: torch.Tensor, factor: torch.Tensor | float, clip_output: bool - ) -> torch.Tensor: - """Adjust the contrast of the image. Args: - image (torch.Tensor): image input of shape (C T H W) (T=1 if single timeframe). - factor (torch.Tensor | float): augmentation factor. - clip_output (bool): whether to clip the output. + data (dict): input data. Returns: - torch.Tensor: output image of shape (C T H W) (T=1 if single timeframe). + tuple: params (i, j, h, w) to be passed to ``crop`` for random crop. """ - if isinstance(factor, float): - factor = torch.as_tensor(factor, device=image.device, dtype=image.dtype) - while len(factor.shape) != len(image.shape): - factor = factor[..., None] - assert factor >= 0, "Contrast factor must be positive" - img_adjust = image * factor - if clip_output: - img_adjust = img_adjust.clamp(min=-1.0, max=1.0) + h, w = data["target"].shape + th, tw = self.size - return img_adjust + if h < th or w < tw: + raise ValueError(f"Required crop size {(th, tw)} is larger than input image size {(h, w)}") - def __getitem__( - self, index: int - ) -> dict[str, torch.Tensor | dict[str, torch.Tensor]]: - """Apply ColorAugmentation to the data. 
+ if w == tw and h == th: + return 0, 0, h, w + + valid_map = data["target"] != self.ignore_index + idx = torch.arange(0, h*w)[valid_map.flatten()] + sample = idx[random.randint(0, idx.shape[0] - 1)] + y, x = sample // w, sample % w + + i = random.randint(max(0, y - th), min(y, h - th + 1)) + j = random.randint(max(0, x - tw), min(x, w - tw + 1)) + + return i, j, th, tw + + +class FocusRandomCropToEncoder(FocusRandomCrop): + def __init__( + self, + pad_if_needed: bool = False, + **meta + ) -> None: + """Initialize the FocusRandomCropToEncoder preprocessor. Args: - index (int): index of data. - Returns: - dict[str, torch.Tensor | dict[str, torch.Tensor]]: output dictionary following the format - {"image": - { - encoder_modality_1: torch.Tensor of shape (C T H W) (T=1 if single timeframe), - ... - encoder_modality_N: torch.Tensor of shape (C T H W) (T=1 if single timeframe), - }, - "target": torch.Tensor of shape (H W), - "metadata": dict}. + pad_if_needed (bool, optional): whether to pad. Defaults to False. + meta: statistics/info of the input data and target encoder + data_mean: global mean value of incoming data for potential padding + ignore_index: ignore index for potential padding """ - data = self.dataset[index] - for k, _ in data["image"].items(): - if k in self.encoder.input_bands: - brightness = random.uniform(-self.brightness, self.brightness) - if random.random() < self.br_probability: - data["image"][k] = self.adjust_brightness( - data["image"][k], brightness, self.clip - ) - - for k, _ in data["image"].items(): - if k in self.encoder.input_bands: - if random.random() < self.ct_probability: - contrast = random.uniform(1 - self.contrast, 1 + self.contrast) - data["image"][k] = self.adjust_contrast( - data["image"][k], contrast, self.clip - ) - - return data + size = meta['encoder_input_size'] + super().__init__( + size, pad_if_needed, **meta + ) -class Resize(BaseAugment): - def __init__(self, dataset: GeoFMDataset, encoder: Encoder, size: int) -> None: - """Initialize the Resize augmentation. +class ImportanceRandomCrop(RandomCrop): + def __init__( + self, + size, + pad_if_needed: bool = False, + num_trials: int = 10, + **meta + ) -> None: + """Initialize the FocusRandomCrop preprocessor. Args: - dataset (GeoFMDataset): dataset used. - encoder (Encoder): encoder used. - size (int): size of the output image. + size (int): crop size. + pad_if_needed (bool, optional): whether to pad. Defaults to False. + num_trials (int, optional): number of trials. Defaults to 10. + meta: statistics/info of the input data and target encoder + data_mean: global mean value of incoming data for potential padding + ignore_index: ignore index for potential padding """ - super().__init__(dataset, encoder) + super().__init__(size, pad_if_needed, **meta) + + self.num_trials = num_trials + self.class_weight = 1 / meta["class_distribution"] - self.size = (size, size) - def __getitem__( - self, index: int + def __call__( + self, data: dict[str, torch.Tensor | dict[str, torch.Tensor]] ) -> dict[str, torch.Tensor | dict[str, torch.Tensor]]: - """Resize the data. + """Random crop the data. Args: - index (int): index of data. + data (dict): input data. Returns: dict[str, torch.Tensor | dict[str, torch.Tensor]]: output dictionary following the format {"image": @@ -809,70 +572,93 @@ def __getitem__( "target": torch.Tensor of shape (H W), "metadata": dict}. 
""" - data = self.dataset[index] - for k, v in data["image"].items(): - if k in self.encoder.input_bands: - data["image"][k] = T.Resize(self.size, interpolation=T.InterpolationMode.BILINEAR, antialias=True)(v) - - if data["target"].ndim == 2: - data["target"] = data["target"].unsqueeze(0) - data["target"] = T.Resize( - self.size, interpolation=T.InterpolationMode.NEAREST - )(data["target"]) - data["target"] = data["target"].squeeze(0) + self.check_size(data) + + if self.pad_if_needed: + data, height, width = self.check_pad(data) else: - data["target"] = T.Resize( - self.size, interpolation=T.InterpolationMode.NEAREST - )(data["target"]) + _, _, height, width = data["image"][list(data["image"].keys())[0]].shape - return data + valid = data["target"] != self.ignore_index + weight_map = torch.full(size=data["target"].shape, fill_value=1e-6, dtype=torch.float) + weight_map[valid] = self.class_weight[data["target"][valid]] + crop_candidates = [self.get_params(data) for _ in range(self.num_trials)] + crop_weights = [weight_map[i:i+h, j:j+w].sum().item()/(h*w) for i, j, h, w in crop_candidates] + crop_weights = np.array(crop_weights) + crop_weights = crop_weights / crop_weights.sum() -class ResizeToEncoder(Resize): - def __init__(self, dataset: GeoFMDataset, encoder: Encoder) -> None: - """Initialize the ResizeToEncoder augmentation. - Resize input data to the encoder input size. - Args: - dataset (GeoFMDataset): dataset used. - encoder (Encoder): encoder used. - """ - super().__init__(dataset, encoder, encoder.input_size) + crop_idx = np.random.choice(self.num_trials, p=crop_weights) + i, j, h, w = crop_candidates[crop_idx] + + for k, v in data["image"].items(): + data["image"][k] = TF.crop(v, i, j, h, w) + data["target"] = TF.crop(data["target"], i, j, h, w) -class RandomCrop(BaseAugment): + return data + + +class ImportanceRandomCropToEncoder(ImportanceRandomCrop): def __init__( self, - dataset: GeoFMDataset, - encoder: Encoder, - size: int, - padding: str | None = None, pad_if_needed: bool = False, - fill: int = 0, - padding_mode: str = "constant", + num_trials: int = 10, + **meta ) -> None: - """Initialize the RandomCrop augmentation. + """Initialize the FocusRandomCrop preprocessor. Args: - dataset (GeoFMDataset): dataset used. - encoder (Encoder): encoder used. size (int): crop size. - padding (str | None, optional): image padding. Defaults to None. pad_if_needed (bool, optional): whether to pad. Defaults to False. - fill (int, optional): value for padding. Defaults to 0. - padding_mode (str, optional): padding mode. Defaults to "constant". + num_trials (int, optional): number of trials. Defaults to 10. + meta: statistics/info of the input data and target encoder + data_mean: global mean value of incoming data for potential padding + ignore_index: ignore index for potential padding """ - super().__init__(dataset, encoder) - self.size = size - self.padding = padding - self.pad_if_needed = pad_if_needed - self.fill = fill - self.padding_mode = padding_mode + size = meta["encoder_input_size"] + super().__init__(size, pad_if_needed, num_trials, **meta) + + +class Resize(BasePreprocessor): + def __init__( + self, + size: int | Sequence[int], + interpolation=T.InterpolationMode.BILINEAR, + antialias: Optional[bool] = True, + resize_target: bool = True, + **meta + ) -> None: + """Initialize the Resize preprocessor. + Args: + size (sequence or int): Desired output size. If size is a sequence like + (h, w), output size will be matched to this. 
+ interpolation (InterpolationMode): Desired interpolation enum defined by + :class:`torchvision.transforms.InterpolationMode`. + antialias (bool, optional): Whether to apply antialiasing. + resize_target (bool, optional): Whether to resize the target + meta: statistics/info of the input data and target encoder + """ + super().__init__() + + if not isinstance(size, (int, Sequence)): + raise TypeError(f"Size should be int or sequence. Got {type(size)}") + if isinstance(size, Sequence) and len(size) not in (1, 2): + raise ValueError("If size is a sequence, it should have 1 or 2 values") + self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.") - def __getitem__( - self, index: int + if isinstance(interpolation, int): + interpolation = TF._interpolation_modes_from_int(interpolation) + + self.interpolation = interpolation + self.antialias = antialias + self.resize_target = resize_target + + def __call__( + self, data: dict[str, torch.Tensor | dict[str, torch.Tensor]] ) -> dict[str, torch.Tensor | dict[str, torch.Tensor]]: - """Random crop the data. + """Resize the data. Args: - index (int): index of data. + data (dict): input data. Returns: dict[str, torch.Tensor | dict[str, torch.Tensor]]: output dictionary following the format {"image": @@ -884,145 +670,186 @@ def __getitem__( "target": torch.Tensor of shape (H W), "metadata": dict}. """ - data = self.dataset[index] - # Use the first image to determine parameters - i, j, h, w = T.RandomCrop.get_params( - data["image"][list(data["image"].keys())[0]], - output_size=(self.size, self.size), - ) for k, v in data["image"].items(): - if k in self.encoder.input_bands: - data["image"][k] = T.functional.crop(v, i, j, h, w) - data["target"] = T.functional.crop(data["target"], i, j, h, w) + data["image"][k] = TF.resize(data["image"][k], self.size, interpolation=self.interpolation, antialias=self.antialias) + + if self.resize_target: + if torch.is_floating_point(data["target"]): + data["target"] = TF.resize(data["target"].unsqueeze(0), size=self.size, interpolation=T.InterpolationMode.BILINEAR).squeeze(0) + else: + data["target"] = TF.resize(data["target"].unsqueeze(0), size=self.size, interpolation=T.InterpolationMode.NEAREST).squeeze(0) return data + def update_meta(self, meta): + """Tracking the meta statistics/info for next processor.""" + meta['data_img_size'] = self.size[0] + return meta -class RandomCropToEncoder(RandomCrop): - def __init__( - self, - dataset: GeoFMDataset, - encoder: Encoder, - padding: str | None = None, - pad_if_needed: bool = False, - fill: int = 0, - padding_mode: str = "constant", - ) -> None: - """Initialize the RandomCropToEncoder augmentation. - Apply RandomCrop to the encoder input size. - Args: - dataset (GeoFMDataset): dataset used. - encoder (Encoder): encoder used. - padding (str | None, optional): image padding. Defaults to None. - pad_if_needed (bool, optional): whether to pad or not. Defaults to False. - fill (int, optional): value for padding. Defaults to 0. - padding_mode (str, optional): padding mode. Defaults to "constant". - """ - size = encoder.input_size - super().__init__( - dataset, encoder, size, padding, pad_if_needed, fill, padding_mode - ) +class ResizeToEncoder(Resize): + def __init__(self, + interpolation=T.InterpolationMode.BILINEAR, + antialias: Optional[bool] = True, + resize_target: bool = True, + **meta + ) -> None: + """Initialize the ResizeToEncoder preprocessor. 
+ Args: + interpolation (InterpolationMode): Desired interpolation enum defined by + :class:`torchvision.transforms.InterpolationMode`. + antialias (bool, optional): Whether to apply antialiasing. + resize_target (bool, optional): Whether to resize the target + meta: statistics/info of the input data and target encoder + """ + size = meta['encoder_input_size'] + super().__init__(size, interpolation, antialias, resize_target, **meta) + + +class RandomResizedCrop(BasePreprocessor): + def __init__(self, + size: int | Sequence[int], + scale: Tuple[float, float] = (0.08, 1.0), + ratio: Tuple[float, float] = (0.75, 1.3333333333333333), + interpolation=T.InterpolationMode.BILINEAR, + antialias: Optional[bool] = True, + resize_target: bool = True, + **meta + )-> None: + """Initialize the RandomResizedCrop preprocessor. + Args: + size (int): crop size. + scale (list): range of scale of the origin size cropped + ratio (list): range of aspect ratio of the origin aspect ratio cropped + interpolation (InterpolationMode): Desired interpolation enum defined by + :class:`torchvision.transforms.InterpolationMode`. + antialias (bool, optional): Whether to apply antialiasing. + resize_target (bool, optional): Whether to resize the target + meta: statistics/info of the input data and target encoder + """ + super().__init__() + self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.") + + if not isinstance(scale, Sequence): + raise TypeError("Scale should be a sequence") + if not isinstance(ratio, Sequence): + raise TypeError("Ratio should be a sequence") + if (scale[0] > scale[1]) or (ratio[0] > ratio[1]): + raise ValueError("Scale and ratio should be of kind (min, max)") + + if isinstance(interpolation, int): + interpolation = TF._interpolation_modes_from_int(interpolation) + + self.interpolation = interpolation + self.antialias = antialias + self.scale = scale + self.ratio = ratio + self.resize_target = resize_target + + @staticmethod + def get_params(input_size: Tuple[int, int], scale: Tuple[float, float], ratio: Tuple[float, float]) -> Tuple[int, int, int, int]: + """Get parameters for ``crop`` for a random sized crop. + + Args: + img (PIL Image or Tensor): Input image. + scale (list): range of scale of the origin size cropped + ratio (list): range of aspect ratio of the origin aspect ratio cropped -class ImportanceRandomCrop(BaseAugment): - def __init__( - self, - dataset: GeoFMDataset, - encoder: Encoder, - size: int, - padding: str | None = None, - pad_if_needed: bool = False, - fill: int = 0, - padding_mode: str = "constant", - n_crops: int = 10, - ) -> None: - """Initialize the ImportanceRandomCrop. + Returns: + tuple: params (i, j, h, w) to be passed to ``crop`` for a random + sized crop. 
+ """ + height, width = input_size + area = height * width + + log_ratio = torch.log(torch.tensor(ratio)) + for _ in range(10): + target_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item() + aspect_ratio = torch.exp(torch.empty(1).uniform_(log_ratio[0], log_ratio[1])).item() + + w = int(round(math.sqrt(target_area * aspect_ratio))) + h = int(round(math.sqrt(target_area / aspect_ratio))) + + if 0 < w <= width and 0 < h <= height: + i = torch.randint(0, height - h + 1, size=(1,)).item() + j = torch.randint(0, width - w + 1, size=(1,)).item() + return i, j, h, w + + # Fallback to central crop + in_ratio = float(width) / float(height) + if in_ratio < min(ratio): + w = width + h = int(round(w / min(ratio))) + elif in_ratio > max(ratio): + h = height + w = int(round(h * max(ratio))) + else: # whole image + w = width + h = height + i = (height - h) // 2 + j = (width - w) // 2 + return i, j, h, w + + def __call__( + self, data: dict[str, torch.Tensor | dict[str, torch.Tensor]] + ) -> dict[str, torch.Tensor | dict[str, torch.Tensor]]: - Args: - dataset (GeoFMDataset): dataset used. - encoder (Encoder): encoder used. - size (int): crop size. - padding (str | None, optional): image padding. Defaults to None. - pad_if_needed (bool, optional): whether to pad. Defaults to False. - fill (int, optional): value for padding. Defaults to 0. - padding_mode (str, optional): padding mode. Defaults to "constant". - n_crops (int, optional): number of crops. Defaults to 10. - """ - super().__init__(dataset, encoder) - self.size = size - self.padding = padding - self.pad_if_needed = pad_if_needed - self.fill = fill - self.padding_mode = padding_mode - self.n_crops = n_crops - - def __getitem__(self, index): - data = self.dataset[index] - - # dataset needs to provide a weighting layer - weight = torch.zeros_like(data["target"]).float() - for i, freq in enumerate(self.distribution): - weight[data["target"] == i] = 1 - freq - weight[data["target"] == self.ignore_index] = 1e-6 - # candidates for random crop - crop_candidates, crop_weights = [], [] - for _ in range(self.n_crops): - i, j, h, w = T.RandomCrop.get_params( - data["image"][ - list(data["image"].keys())[0] - ], # Use the first image to determine parameters - output_size=(self.size, self.size), - ) - crop_candidates.append((i, j, h, w)) + self.check_size(data) - crop_weight = T.functional.crop(weight, i, j, h, w) - crop_weights.append(torch.sum(crop_weight).item()) + _, t, h_img, w_img = data["image"][list(data["image"].keys())[0]].shape - crop_weights = np.array(crop_weights) - crop_weights = crop_weights / np.sum(crop_weights) - crop_idx = np.random.choice(self.n_crops, p=crop_weights) - i, j, h, w = crop_candidates[crop_idx] + i, j, h, w = self.get_params((h_img, w_img), self.scale, self.ratio) for k, v in data["image"].items(): - if k in self.encoder.input_bands: - data["image"][k] = T.functional.crop(v, i, j, h, w) - data["target"] = T.functional.crop(data["target"], i, j, h, w) + data["image"][k] = TF.resized_crop(data["image"][k], i, j, h, w, self.size, self.interpolation, antialias=self.antialias) - return data + if self.resize_target: + if torch.is_floating_point(data["target"]): + data["target"] = TF.resized_crop(data["target"].unsqueeze(0), i, j, h, w, self.size, T.InterpolationMode.BILINEAR).squeeze(0) + else: + data["target"] = TF.resized_crop(data["target"].unsqueeze(0), i, j, h, w, self.size, T.InterpolationMode.NEAREST).squeeze(0) + else: + data["target"] = TF.crop(data["target"], i, j, h, w) -class 
ImportanceRandomCropToEncoder(ImportanceRandomCrop): - def __init__( - self, - dataset: GeoFMDataset, - encoder: Encoder, - padding: str | None = None, - pad_if_needed: bool = False, - fill: int = 0, - padding_mode: str = "constant", - n_crops: int = 10, - ) -> None: - """Initialize the ImportanceRandomCropToEncoder. - Initialize the ImportanceRandomCrop augmentation to the encoder input size. + return data - Args: - dataset (GeoFMDataset): dataset used. - encoder (Encoder): encoder used. - size (int): crop size. - padding (str | None, optional): image padding. Defaults to None. - pad_if_needed (bool, optional): whether to pad. Defaults to False. - fill (int, optional): value for padding. Defaults to 0. - padding_mode (str, optional): padding mode. Defaults to "constant". - n_crops (int, optional): number of crops. Defaults to 10. - """ - size = encoder.input_size - super().__init__( - dataset=dataset, - encoder=encoder, - size=size, - padding=padding, - pad_if_needed=pad_if_needed, - fill=fill, - padding_mode=padding_mode, - n_crops=n_crops, - ) + def update_meta(self, meta): + """Tracking the meta statistics/info for next processor.""" + meta['data_img_size'] = self.size[0] + return meta + +class RandomResizedCropToEncoder(RandomResizedCrop): + def __init__(self, + scale: Tuple[float, float] = (0.08, 1.0), + ratio: Tuple[float, float] = (0.75, 1.3333333333333333), + interpolation=T.InterpolationMode.BILINEAR, + antialias: Optional[bool] = True, + resize_target: bool = True, + **meta + ) -> None: + """Initialize the RandomResizedCropToEncoder preprocessor. + Args: + scale (list): range of scale of the origin size cropped + ratio (list): range of aspect ratio of the origin aspect ratio cropped + interpolation (InterpolationMode): Desired interpolation enum defined by + :class:`torchvision.transforms.InterpolationMode`. + antialias (bool, optional): Whether to apply antialiasing. 
+ resize_target (bool, optional): Whether to resize the target + meta: statistics/info of the input data and target encoder + """ + size = meta['encoder_input_size'] + super().__init__(size, scale, ratio, interpolation, antialias, resize_target, **meta) + + +def _setup_size(size, error_msg): + if isinstance(size, numbers.Number): + return int(size), int(size) + + if isinstance(size, Sequence) and len(size) == 1: + return size[0], size[0] + + if len(size) != 2: + raise ValueError(error_msg) + + return size \ No newline at end of file diff --git a/pangaea/engine/evaluator.py b/pangaea/engine/evaluator.py index bfe6824c..5e6ec7ea 100644 --- a/pangaea/engine/evaluator.py +++ b/pangaea/engine/evaluator.py @@ -2,6 +2,8 @@ import os import time from pathlib import Path +import math +import wandb import torch import torch.nn.functional as F @@ -38,33 +40,34 @@ class Evaluator: """ def __init__( - self, - val_loader: DataLoader, - exp_dir: str | Path, - device: torch.device, - use_wandb: bool, + self, + val_loader: DataLoader, + exp_dir: str | Path, + device: torch.device, + inference_mode: str = 'sliding', + sliding_inference_batch: int = None, + use_wandb: bool = False, ) -> None: + self.rank = int(os.environ["RANK"]) self.val_loader = val_loader self.logger = logging.getLogger() self.exp_dir = exp_dir self.device = device + self.inference_mode = inference_mode + self.sliding_inference_batch = sliding_inference_batch self.classes = self.val_loader.dataset.classes self.split = self.val_loader.dataset.split self.ignore_index = self.val_loader.dataset.ignore_index self.num_classes = len(self.classes) self.max_name_len = max([len(name) for name in self.classes]) - self.use_wandb = use_wandb - - if use_wandb: - import wandb - self.wandb = wandb + self.use_wandb = use_wandb def evaluate( - self, - model: torch.nn.Module, - model_name: str, - model_ckpt_path: str | Path | None = None, + self, + model: torch.nn.Module, + model_name: str, + model_ckpt_path: str | Path | None = None, ) -> None: raise NotImplementedError @@ -77,6 +80,53 @@ def compute_metrics(self): def log_metrics(self, metrics): pass + @staticmethod + def sliding_inference(model, img, input_size, output_shape=None, stride=None, max_batch=None): + b, c, t, height, width = img[list(img.keys())[0]].shape + + if stride is None: + h = int(math.ceil(height / input_size)) + w = int(math.ceil(width / input_size)) + else: + h = math.ceil((height - input_size) / stride) + 1 + w = math.ceil((width - input_size) / stride) + 1 + + h_grid = torch.linspace(0, height - input_size, h).round().long() + w_grid = torch.linspace(0, width - input_size, w).round().long() + num_crops_per_img = h * w + + for k, v in img.items(): + img_crops = [] + for i in range(h): + for j in range(w): + img_crops.append(v[:, :, :, h_grid[i]:h_grid[i] + input_size, w_grid[j]:w_grid[j] + input_size]) + img[k] = torch.cat(img_crops, dim=0) + + pred = [] + max_batch = max_batch if max_batch is not None else b * num_crops_per_img + batch_num = int(math.ceil(b * num_crops_per_img / max_batch)) + for i in range(batch_num): + img_ = {k: v[max_batch * i: min(max_batch * i + max_batch, b * num_crops_per_img)] for k, v in img.items()} + pred_ = model.forward(img_, output_shape=(input_size, input_size)) + pred.append(pred_) + pred = torch.cat(pred, dim=0) + pred = pred.view(num_crops_per_img, b, -1, input_size, input_size).transpose(0, 1) + + merged_pred = torch.zeros((b, pred.shape[2], height, width), device=pred.device) + pred_count = torch.zeros((b, height, width), dtype=torch.long, 
device=pred.device) + for i in range(h): + for j in range(w): + merged_pred[:, :, h_grid[i]:h_grid[i] + input_size, + w_grid[j]:w_grid[j] + input_size] += pred[:, h * i + j] + pred_count[:, h_grid[i]:h_grid[i] + input_size, + w_grid[j]:w_grid[j] + input_size] += 1 + + merged_pred = merged_pred / pred_count.unsqueeze(1) + if output_shape is not None: + merged_pred = F.interpolate(merged_pred, size=output_shape, mode="bilinear") + + return merged_pred + class SegEvaluator(Evaluator): """ @@ -100,13 +150,15 @@ class SegEvaluator(Evaluator): """ def __init__( - self, - val_loader: DataLoader, - exp_dir: str | Path, - device: torch.device, - use_wandb: bool, + self, + val_loader: DataLoader, + exp_dir: str | Path, + device: torch.device, + inference_mode: str = 'sliding', + sliding_inference_batch: int = None, + use_wandb: bool = False, ): - super().__init__(val_loader, exp_dir, device, use_wandb) + super().__init__(val_loader, exp_dir, device, inference_mode, sliding_inference_batch, use_wandb) @torch.no_grad() def evaluate(self, model, model_name='model', model_ckpt_path=None): @@ -129,11 +181,19 @@ def evaluate(self, model, model_name='model', model_ckpt_path=None): ) for batch_idx, data in enumerate(tqdm(self.val_loader, desc=tag)): + image, target = data["image"], data["target"] image = {k: v.to(self.device) for k, v in image.items()} target = target.to(self.device) - logits = model(image, output_shape=target.shape[-2:]) + if self.inference_mode == "sliding": + input_size = model.module.encoder.input_size + logits = self.sliding_inference(model, image, input_size, output_shape=target.shape[-2:], + max_batch=self.sliding_inference_batch) + elif self.inference_mode == "whole": + logits = model(image, output_shape=target.shape[-2:]) + else: + raise NotImplementedError((f"Inference mode {self.inference_mode} is not implemented.")) if logits.shape[1] == 1: pred = (torch.sigmoid(logits) > 0.5).type(torch.int64).squeeze(dim=1) else: @@ -141,14 +201,14 @@ def evaluate(self, model, model_name='model', model_ckpt_path=None): valid_mask = target != self.ignore_index pred, target = pred[valid_mask], target[valid_mask] count = torch.bincount( - (pred * self.num_classes + target), minlength=self.num_classes**2 + (pred * self.num_classes + target), minlength=self.num_classes ** 2 ) confusion_matrix += count.view(self.num_classes, self.num_classes) torch.distributed.all_reduce( confusion_matrix, op=torch.distributed.ReduceOp.SUM ) - metrics = self.compute_metrics(confusion_matrix) + metrics = self.compute_metrics(confusion_matrix.cpu()) self.log_metrics(metrics) used_time = time.time() - t @@ -200,16 +260,16 @@ def log_metrics(self, metrics): def format_metric(name, values, mean_value): header = f"------- {name} --------\n" metric_str = ( - "\n".join( - c.ljust(self.max_name_len, " ") + "\t{:>7}".format("%.3f" % num) - for c, num in zip(self.classes, values) - ) - + "\n" + "\n".join( + c.ljust(self.max_name_len, " ") + "\t{:>7}".format("%.3f" % num) + for c, num in zip(self.classes, values) + ) + + "\n" ) mean_str = ( - "-------------------\n" - + "Mean".ljust(self.max_name_len, " ") - + "\t{:>7}".format("%.3f" % mean_value) + "-------------------\n" + + "Mean".ljust(self.max_name_len, " ") + + "\t{:>7}".format("%.3f" % mean_value) ) return header + metric_str + mean_str @@ -231,7 +291,7 @@ def format_metric(name, values, mean_value): self.logger.info(macc_str) if self.use_wandb: - self.wandb.log( + wandb.log( { f"{self.split}_mIoU": metrics["mIoU"], f"{self.split}_mF1": metrics["mF1"], @@ -274,18 +334,20 
@@ -274,18 +334,20 @@ class RegEvaluator(Evaluator):
     """

     def __init__(
-        self,
-        val_loader: DataLoader,
-        exp_dir: str | Path,
-        device: torch.device,
-        use_wandb: bool,
+            self,
+            val_loader: DataLoader,
+            exp_dir: str | Path,
+            device: torch.device,
+            inference_mode: str = 'sliding',
+            sliding_inference_batch: int = None,
+            use_wandb: bool = False,
     ):
-        super().__init__(val_loader, exp_dir, device, use_wandb)
+        super().__init__(val_loader, exp_dir, device, inference_mode, sliding_inference_batch, use_wandb)

     @torch.no_grad()
     def evaluate(self, model, model_name='model', model_ckpt_path=None):
         t = time.time()
-
+
         if model_ckpt_path is not None:
             model_dict = torch.load(model_ckpt_path, map_location=self.device)
             model_name = os.path.basename(model_ckpt_path).split('.')[0]
@@ -299,17 +361,29 @@ def evaluate(self, model, model_name='model', model_ckpt_path=None):
         model.eval()

         tag = f'Evaluating {model_name} on {self.split} set'
-        mse = 0.0
+
+        mse = torch.zeros(1, device=self.device)
+
         for batch_idx, data in enumerate(tqdm(self.val_loader, desc=tag)):
             image, target = data['image'], data['target']
             image = {k: v.to(self.device) for k, v in image.items()}
             target = target.to(self.device)

-            logits = model(image, output_shape=target.shape[-2:]).squeeze(dim=1)
-            loss = F.mse_loss(logits, target)
-            mse +=loss
-        mse_avg = mse / len(self.val_loader)
-        metrics = {"MSE" : mse_avg.item(), "RMSE" : torch.sqrt(mse_avg).item()}
+            if self.inference_mode == "sliding":
+                input_size = model.module.encoder.input_size
+                logits = self.sliding_inference(model, image, input_size, output_shape=target.shape[-2:],
+                                                max_batch=self.sliding_inference_batch)
+            elif self.inference_mode == "whole":
+                logits = model(image, output_shape=target.shape[-2:]).squeeze(dim=1)
+            else:
+                raise NotImplementedError((f"Inference mode {self.inference_mode} is not implemented."))
+
+            mse += F.mse_loss(logits, target, reduction='sum')
+
+        torch.distributed.all_reduce(mse, op=torch.distributed.ReduceOp.SUM)
+        mse = mse / len(self.val_loader.dataset)
+
+        metrics = {"MSE": mse.item(), "RMSE": torch.sqrt(mse).item()}
         self.log_metrics(metrics)

         used_time = time.time() - t
@@ -322,9 +396,9 @@ def __call__(self, model, model_name='model', model_ckpt_path=None):

     def log_metrics(self, metrics):
         header = "------- MSE and RMSE --------\n"
-        mse = "-------------------\n" + 'MSE \t{:>7}'.format('%.3f' % metrics['MSE'])+'\n'
+        mse = "-------------------\n" + 'MSE \t{:>7}'.format('%.3f' % metrics['MSE']) + '\n'
         rmse = "-------------------\n" + 'RMSE \t{:>7}'.format('%.3f' % metrics['RMSE'])
-        self.logger.info(header+mse+rmse)
+        self.logger.info(header + mse + rmse)

         if self.use_wandb:
-            self.wandb.log({f"{self.split}_MSE": metrics["MSE"], f"{self.split}_RMSE": metrics["RMSE"]})
+            wandb.log({f"{self.split}_MSE": metrics["MSE"], f"{self.split}_RMSE": metrics["RMSE"]})
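Side note on the RegEvaluator change above: the running scalar average is replaced by a per-rank sum of squared errors that is all-reduced across ranks and normalized once at the end, so the reported MSE/RMSE no longer depends on how batches are sharded. A rough sketch of that aggregation pattern (assumes an initialized process group; the function name and the model call signature are illustrative, following the diff):

import torch
import torch.nn.functional as F
import torch.distributed as dist

def distributed_mse(model, loader, device, dataset_size):
    """Accumulate a summed squared error per rank, then combine across ranks."""
    sse = torch.zeros(1, device=device)
    for data in loader:
        image = {k: v.to(device) for k, v in data["image"].items()}
        target = data["target"].to(device)
        logits = model(image, output_shape=target.shape[-2:]).squeeze(dim=1)
        sse += F.mse_loss(logits, target, reduction="sum")
    dist.all_reduce(sse, op=dist.ReduceOp.SUM)   # sum the per-rank totals
    mse = sse / dataset_size                     # normalize once, as the diff does with len(dataset)
    return {"MSE": mse.item(), "RMSE": torch.sqrt(mse).item()}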
diff --git a/pangaea/engine/trainer.py b/pangaea/engine/trainer.py
index d10cec40..5aebcdae 100644
--- a/pangaea/engine/trainer.py
+++ b/pangaea/engine/trainer.py
@@ -141,6 +141,7 @@ def train_one_epoch(self, epoch: int) -> None:
             image, target = data["image"], data["target"]
             image = {modality: value.to(self.device) for modality, value in image.items()}
             target = target.to(self.device)
+
             self.training_stats["data_time"].update(time.time() - end_time)

             with torch.autocast(
@@ -151,17 +152,19 @@ def train_one_epoch(self, epoch: int) -> None:

             self.optimizer.zero_grad()

-            if not torch.isnan(loss):
-                self.scaler.scale(loss).backward()
-                self.scaler.step(self.optimizer)
-                self.scaler.update()
-
-                self.training_stats['loss'].update(loss.item())
-                with torch.no_grad():
-                    self.compute_logging_metrics(logits, target)
-                if (batch_idx + 1) % self.log_interval == 0:
-                    self.log(batch_idx + 1, epoch)
-            else:
-                self.logger.warning("Skip batch {} because of nan loss".format(batch_idx + 1))
+            if not torch.isfinite(loss):
+                raise FloatingPointError(
+                    f"Rank {self.rank} got infinite/NaN loss at batch {batch_idx} of epoch {epoch}!"
+                )
+
+            self.scaler.scale(loss).backward()
+            self.scaler.step(self.optimizer)
+            self.scaler.update()
+            self.training_stats['loss'].update(loss.item())
+            with torch.no_grad():
+                self.compute_logging_metrics(logits, target)
+            if (batch_idx + 1) % self.log_interval == 0:
+                self.log(batch_idx + 1, epoch)

         self.lr_scheduler.step()
@@ -216,6 +219,7 @@ def save_model(
             checkpoint (dict[str, dict | int] | None, optional): already prepared checkpoint dict. Defaults to None.
         """
         if self.rank != 0:
+            torch.distributed.barrier()
             return
         checkpoint = self.get_checkpoint(epoch) if checkpoint is None else checkpoint
         suffix = "_best" if is_best else "_final" if is_final else ""
@@ -224,6 +228,8 @@ def save_model(
         self.logger.info(
             f"Epoch {epoch} | Training checkpoint saved at {checkpoint_path}"
         )
+        torch.distributed.barrier()
+        return

     def load_model(self, resume_path: str | pathlib.Path) -> None:
         """Load model from the checkpoint.
@@ -307,9 +313,8 @@ def log(self, batch_idx: int, epoch) -> None:
         left_batch_all = (
             self.batch_per_epoch * (self.n_epochs - epoch - 1) + left_batch_this_epoch
         )
-        left_eval_times = (
-            self.n_epochs + 0.5
-        ) // self.eval_interval - self.training_stats["eval_time"].count
+        left_eval_times = ((self.n_epochs - 0.5) // self.eval_interval + 2
+                           - self.training_stats["eval_time"].count)
         left_time_this_epoch = sec_to_hm(
             left_batch_this_epoch * self.training_stats["batch_time"].avg
         )
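The save_model change above pairs the rank-0 write with barriers so that non-zero ranks wait for the checkpoint to land on disk before moving on. A minimal sketch of that rank-0-saves-then-synchronize pattern (illustrative helper, assumes torch.distributed is initialized):

import torch
import torch.distributed as dist

def save_checkpoint_synchronized(state: dict, path: str, rank: int) -> None:
    # Every rank reaches the barrier, but only rank 0 performs the write,
    # so no rank can race ahead and, e.g., try to load a half-written file.
    if rank != 0:
        dist.barrier()
        return
    torch.save(state, path)
    dist.barrier()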
diff --git a/pangaea/run.py b/pangaea/run.py
index a5e3687d..8ab6518c 100644
--- a/pangaea/run.py
+++ b/pangaea/run.py
@@ -26,7 +26,7 @@
     seed_worker,
 )
 from pangaea.utils.subset_sampler import get_subset_indices
-from pangaea.datasets.base import GeoFMSubset
+from pangaea.datasets.base import GeoFMSubset, GeoFMDataset, RawGeoFMDataset


 def get_exp_info(hydra_config: HydraConf) -> str:
@@ -121,12 +121,6 @@ def main(cfg: DictConfig) -> None:
     logger.info("The experiment is stored in %s\n" % exp_dir)
     logger.info(f"Device used: {device}")

-    # get datasets
-    train_dataset: Dataset = instantiate(cfg.dataset, split="train")
-    val_dataset: Dataset = instantiate(cfg.dataset, split="val")
-    test_dataset: Dataset = instantiate(cfg.dataset, split="test")
-    logger.info("Built {} dataset.".format(cfg.dataset.dataset_name))
-
     encoder: Encoder = instantiate(cfg.encoder)
     encoder.load_encoder_weights(logger)
     logger.info("Built {}.".format(encoder.model_name))
@@ -151,15 +145,20 @@ def main(cfg: DictConfig) -> None:

     # training
     if train_run:
+        # get preprocessor
+        train_preprocessor = instantiate(cfg.preprocessing.train, dataset_cfg=cfg.dataset, encoder_cfg=cfg.encoder,
+                                         _recursive_=False)
+        val_preprocessor = instantiate(cfg.preprocessing.val, dataset_cfg=cfg.dataset, encoder_cfg=cfg.encoder,
+                                       _recursive_=False)
+
+        # get datasets
+        raw_train_dataset: RawGeoFMDataset = instantiate(cfg.dataset, split="train")
+        raw_val_dataset: RawGeoFMDataset = instantiate(cfg.dataset, split="val")
+        train_dataset = GeoFMDataset(raw_train_dataset, train_preprocessor)
+        val_dataset = GeoFMDataset(raw_val_dataset, val_preprocessor)
+
+        logger.info("Built {} dataset.".format(cfg.dataset.dataset_name))
-        for preprocess in cfg.preprocessing.train:
-            train_dataset: Dataset = instantiate(
-                preprocess, dataset=train_dataset, encoder=encoder
-            )
-        for preprocess in cfg.preprocessing.test:
-            val_dataset: Dataset = instantiate(
-                preprocess, dataset=val_dataset, encoder=encoder
-            )

         if 0 < cfg.limited_label_train < 1:
             indices = get_subset_indices(
@@ -198,8 +197,8 @@ def main(cfg: DictConfig) -> None:
         val_loader = DataLoader(
             val_dataset,
             sampler=DistributedSampler(val_dataset),
-            batch_size=cfg.batch_size,
-            num_workers=cfg.num_workers,
+            batch_size=cfg.test_batch_size,
+            num_workers=cfg.test_num_workers,
             pin_memory=True,
             persistent_workers=False,
             worker_init_fn=seed_worker,
@@ -237,16 +236,18 @@ def main(cfg: DictConfig) -> None:
         trainer.train()

     # Evaluation
-    for preprocess in cfg.preprocessing.test:
-        test_dataset: Dataset = instantiate(
-            preprocess, dataset=test_dataset, encoder=encoder
-        )
+    test_preprocessor = instantiate(cfg.preprocessing.test, dataset_cfg=cfg.dataset, encoder_cfg=cfg.encoder,
+                                    _recursive_=False)
+
+    # get datasets
+    raw_test_dataset: RawGeoFMDataset = instantiate(cfg.dataset, split="test")
+    test_dataset = GeoFMDataset(raw_test_dataset, test_preprocessor)

     test_loader = DataLoader(
         test_dataset,
         sampler=DistributedSampler(test_dataset),
-        batch_size=cfg.batch_size,
-        num_workers=cfg.num_workers,
+        batch_size=cfg.test_batch_size,
+        num_workers=cfg.test_num_workers,
         pin_memory=True,
         persistent_workers=False,
         drop_last=False,
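The run.py changes above replace the old loop of per-transform wrappers with a single preprocessor that is instantiated from the config and applied by a dataset wrapper. Conceptually the wrapper just runs the preprocessor on every raw sample; a simplified sketch of that wrapping pattern (this is not the actual GeoFMDataset implementation, only the idea it follows):

from torch.utils.data import Dataset

class PreprocessedDataset(Dataset):
    """Wrap a raw dataset and apply a preprocessing callable to each sample."""

    def __init__(self, raw_dataset: Dataset, preprocessor):
        self.raw_dataset = raw_dataset
        self.preprocessor = preprocessor

    def __len__(self) -> int:
        return len(self.raw_dataset)

    def __getitem__(self, index: int):
        # Samples in this codebase are dicts like {"image": {...}, "target": ...}.
        return self.preprocessor(self.raw_dataset[index])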