diff --git a/.github/CONTRIBUTING.md b/.github/CONTRIBUTING.md
index 8167c410..21f4d46d 100644
--- a/.github/CONTRIBUTING.md
+++ b/.github/CONTRIBUTING.md
@@ -6,6 +6,7 @@ We welcome all forms of contributions, including but not limited to the followin
 - Incorporate downstream datasets
 - Add new decoder heads
 - Fix typo or bugs
+- Add new decoders

 ### Workflow

@@ -16,6 +17,37 @@ We welcome all forms of contributions, including but not limited to the followin

 Note: For significant modifications or any bugs spotting, please consider opening an issue for discussion beforehand.

+## Code structure
+
+### engines
+In engines, the basic modules of the training pipeline are defined, including the data_preprocessor, trainer and evaluator.
+1. data_preprocessor selects the bands needed by an encoder, pads unavailable bands with zeros, and applies the configured augmentations.
+2. trainer supports mixed precision/distributed training and prints training stats and metrics in real time.
+3. evaluator can be called independently; it evaluates a model (also in a distributed way) and computes per-class metrics.
+
+### datasets
+1. The implementations are simplified and standardized.
+2. Dataset metas are read from configs, including newly added classes (name), ignore_index, and so on.
+3. To add (register) a new dataset implementation, use the decorator ```@DATASET_REGISTRY.register()```.
+
+### foundation_models
+1. Multi-stage outputs that may be needed by segmentors are supported; the layers to return are specified by the output layers in the encoder config.
+2. All the encoders should work properly.
+3. To add (register) a new encoder implementation, use the decorator ```@ENCODER_REGISTRY.register()```.
+
+### segmentors
+1. The UperNet implementation is based on [mmsegmentation](https://github.com/open-mmlab/mmsegmentation/tree/main).
+2. To add (register) a new segmentor implementation, use the decorator ```@SEGMENTOR_REGISTRY.register()```.
+3. So far, we have UPerNet for unitemporal semantic segmentation, UPerNetCD for change detection and MTUPerNet for multitemporal semantic segmentation.
+4. For the multi-temporal case, L-TAE and linear projection are supported.
+
+### augmentations
+1. All the available augmentations are in ```data_preprocessor.py```.
+2. To add (register) a new augmentation implementation, use the decorator ```@AUGMENTER_REGISTRY.register()```.
+
+All the parameters can also be set in the run config file.
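+
+For illustration, below is a minimal sketch of what registering a new dataset could look like (the class name, file name and tensor shapes are made up; the decorator, the `get_splits`/`download` hooks and the output dictionary follow the existing implementations such as `datasets/pastis.py`):
+
+```python
+# datasets/my_dataset.py (hypothetical file)
+import torch
+from torch.utils.data import Dataset
+
+from utils.registry import DATASET_REGISTRY
+
+
+@DATASET_REGISTRY.register()
+class MyDataset(Dataset):
+    def __init__(self, cfg, split, is_train=True):
+        super().__init__()
+        self.root_path = cfg["root_path"]  # paths, bands and statistics come from the dataset config
+        self.split = split
+        self.is_train = is_train
+
+    def __len__(self):
+        return 1  # number of samples in this split
+
+    def __getitem__(self, index):
+        # the data_preprocessor expects this dictionary layout
+        image = torch.zeros(4, 256, 256)        # C x H x W, bands listed in the config
+        target = torch.zeros(256, 256).long()   # per-pixel class indices
+        return {"image": {"optical": image}, "target": target, "metadata": {}}
+
+    @staticmethod
+    def get_splits(dataset_config):
+        return (MyDataset(dataset_config, "train", is_train=True),
+                MyDataset(dataset_config, "val", is_train=False),
+                MyDataset(dataset_config, "test", is_train=False))
+
+    @staticmethod
+    def download(dataset_config: dict, silent=False):
+        pass  # implement automatic download or leave as a no-op
+```
+
+The same pattern applies to encoders, segmentors and augmentations with their respective registries.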
+
+## Adding new features

### Adding a new geospatial foundation model

1. Inside the `foundation_models` folder:
diff --git a/README.md b/README.md
index d9a98bed..566d259b 100644
--- a/README.md
+++ b/README.md
@@ -1,42 +1,51 @@
 [![Tests](https://github.com/yurujaja/geofm-bench/actions/workflows/python-test.yml/badge.svg)](https://github.com/yurujaja/geofm-bench/actions/workflows/python-test.yml)
-## Introduction
-(TBD)
-
-### engines
-In engines, basic modules in the training pipeline are defined including data_preprocessor, trainer and evaluator.
-1. data_preprocessor replaced the previous adaptation.py, i.e., selects the bands needed by an encoder and pads unavailable bands with zeros, and different augmentations.
-2. trainer now support mixed precision/distributed training and print training stats and metrics in real time.
-3. evaluator can be called independently and evaluate a model also in distributed way and compute per class metrics.
-4. see run.py for how to assemble these modules and concatenate them
-
-### datasets
-1. The implementations are simplified and standardized (I try my best).
-2. Dataset metas are read from configs, including newly added classes (name), ignore_index, and so on.
-3.Mados, sen1floods, hlsburnscars, xView2, biomasster are supported by this branch currently.
-4. To add (register) a new dataset implementation, use the decorator @DATASET_REGISTRY.register().
-
-### foundation_models
-1. Remove all irrelevant modules and functions used in pre-training. Only keep the essential modules in encoders for extracting features.
-2. Support multi-stage output that may be needed by segmentors, specified by output layers in encoder config.
-3. All the encoder should work properly.
-4. To add (register) a new encoder implementation, use the decorator @ENCODER_REGISTRY.register().
-
-### segmentors
-1. Now the UperNet implementation is based on mmsegmentation, which is more likely correct: https://github.com/open-mmlab/mmsegmentation/tree/main
-2. We can copypaste more segmentors later.
-3. To add (register) a new encoder implementation, use the decorator @SEGMENTOR_REGISTRY.register().
-4. So far, we have UPerNet for unitemporal semantic segmentation, UPerNetCD for change detection and MTUPerNet for multitemporal semantic segmentation
-5. for multi-temporal, L-TAE and linear projection are supported
-
-All of these parameters can also be set in the run config file.
-
-To use more gpus or nodes, set `--nnodes` and `--nproc_per_node` correspondingly, see:
-https://pytorch.org/docs/stable/elastic/run.html
-
-To use mixed precision training, specify either `--fp16` for float16 and or `--bf16` for bfloat16
-
-For fine-tuning instead of linear probing, specify `--finetune`.
+# TITLE
+
+## 📚 Introduction
+
+While geospatial foundation models (GFMs) have proliferated rapidly, their evaluations remain inconsistent and narrow. Existing works often utilize suboptimal downstream datasets (e.g., EuroSAT) and tasks (e.g., land cover classification), which constrain comparability and real-world usability. Additionally, a lack of diversity in evaluation protocols, including image resolution and sensor types, further complicates the extensive assessments of GFM performance. To bridge this gap, we propose a standardized evaluation protocol that incorporates a wide-ranging selection of datasets, tasks, resolutions, and sensor types, establishing a robust and widely applicable benchmark for GFMs.
+
+In this repo, you can find the code to benchmark GFMs. For the moment, we have included several GFMs that take different approaches. We look forward to adding new models and datasets.
+
+For the moment, we support the following **models**:
+
+|             | Paper | GitHub | Keywords |
+|:-----------:|:-----:|:------:|:--------:|
+| SSL4EOS12   |       |        |          |
+| Scale-MAE   |       |        |          |
+| SatlasNet   |       |        |          |
+| GFM         |       |        |          |
+| SpectralGPT |       |        |          |
+| DOFA        |       |        |          |
+| CROMA       |       |        |          |
+| Prithvi     |       |        |          |
+| RemoteCLIP  |       |        |          |
+
+And the following **datasets**:
+
+|                     | Paper | Download | Domain | Task | Sensors | Location |
+|:-------------------:|:-----:|:--------:|:------:|:----:|---------|----------|
+| HLS Burn Scars      |       |          |        |      |         |          |
+| MADOS               |       |          |        |      |         |          |
+| PASTIS              |       |          |        |      |         |          |
+| Sen1Floods11        |       |          |        |      |         |          |
+| xView2              |       |          |        |      |         |          |
+| Five Billion Pixels |       |          |        |      |         |          |
+| DynamicEarthNet     |       |          |        |      |         |          |
+| CropTypeMapping     |       |          |        |      |         |          |
+| SpaceNet7           |       |          |        |      |         |          |
+| AI4SmallFarms       |       |          |        |      |         |          |
+| BioMassters         |       |          |        |      |         |          |
+
+The repository supports the following **tasks** using GFMs:
+ - [single temporal semantic segmentation](#single-temporal-semantic-segmentation)
+ - [multi-temporal semantic segmentation](#multi-temporal-semantic-segmentation)
+ - [change detection](#change-detection)
+ - [single temporal regression](#single-temporal-regression)
+ - [multi-temporal regression](#multi-temporal-regression)
+
+It is also possible to train some [supervised baselines](#-fully-supervised-training), based on UNet.

 ## 🛠ī¸ Setup
 Clone the repository:
 ```
@@ -63,16 +72,25 @@ mamba activate geofm-bench8
 ```

 ## 🏋ī¸ Training
+
 There are 5 basic component types in our config system:
-- `config`: Information of training settings such as batch size, epochs, use wandb. `limited_label` is to indicate the percentage of dataset used for training, for example, `-1` means the full training dataset is used while `0.5` means 50% used.
+
+- `config`: Information of training settings such as batch size, epochs, use wandb. `limited_label` indicates the fraction of the training dataset used: for example, `-1` means the full training dataset is used, while `0.5` means 50% is used (see the sketch after this list). #strategy used
- `encoder_config`: GFM encoder related parameters. `output_layers` is used for which layers are used for Upernet decoder.
- `dataset_config`: Information of downstream datasets such as image size, band_statistics, etc.
- `segmentor_config`: Downstream task decoder fine-tuning related parameters, including the head type, loss, optimizer, scheduler, etc.
- `augmentation_config`: Both preprocessing and augmentations steps required for the dataset, such as bands adaptation, normalization, resize/crop.
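+
+For reference, `limited_label` is applied by randomly subsampling the training split; a simplified sketch of the corresponding logic in `run.py` (the helper function name is ours):
+
+```python
+import random
+from torch.utils.data import Subset
+
+def subsample_training_set(train_dataset, limited_label):
+    # limited_label == -1 keeps the full training set;
+    # a value in (0, 1) keeps that fraction of randomly drawn samples
+    if 0 < limited_label < 1:
+        n_kept = int(len(train_dataset) * limited_label)
+        indices = random.sample(range(len(train_dataset)), n_kept)
+        train_dataset = Subset(train_dataset, indices)
+    return train_dataset
+```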
-We provide several examples of command lines to initilize different training tasks on single gpu.
+We provide several examples of command lines to initialize different training tasks on a single GPU.
+
+Please note:
+ - Command-line parameters take priority over the parameters in the config files. So, if you want to change e.g. the batch size without changing the `config`, you can just add `--batch_size n` to the command line.
+ - To use more GPUs or nodes, set `--nnodes` and `--nproc_per_node` correspondingly, see:
+https://pytorch.org/docs/stable/elastic/run.html
+ - To use mixed precision training, specify either `--fp16` for float16 or `--bf16` for bfloat16.
+
 ### đŸ’ģ Decoder Finetuning
-**Single Temporal Semantic Segmentation**
+#### Single Temporal Semantic Segmentation
 Take MADOS dataset, Prithvi Encoder and Upernet Decoder as example:
 ```
@@ -85,7 +103,7 @@ torchrun --nnodes=1 --nproc_per_node=1 run.py \
 --num_workers 4 --eval_interval 1 --use_wandb
 ```
-**Multi Temporal Semantic Segmentation**
+#### Multi-Temporal Semantic Segmentation
 Multi-temporal model `configs/segmentors/upernet_mt.yaml` should be used. In addition, in the dataset config, indicate the number of time frames, e.g., `multi_temporal: 6`
 ```
@@ -98,40 +116,94 @@ torchrun --nnodes=1 --nproc_per_node=1 run.py \
 --num_workers 4 --eval_interval 1 --use_wandb
 ```
-**Multi Temporal Change Detection**
+#### Change Detection
+```
+torchrun ...
+```
+#### Single Temporal Regression
 ```
 torchrun ...
 ```
-**Multi Temporal Regression**
+#### Multi-Temporal Regression
 ```
 torchrun ...
 ```
 ### đŸ’ģ Fully Supervised Training
-**Single Temporal Change Detection**
+#### Single Temporal Semantic Segmentation
 ```
 torchrun ...
 ```
+In general
+
+
 ## 🏃 Evaluation
 Indicate the `eval_dir` where the checkpoints and configurations are stored.
+
 ```
 torchrun --nnodes=1 --nproc_per_node=1 run.py --batch_size 1 --eval_dir work-dir/the-folder-where-your-exp-is-saved
 ```
-
 ## ✏ī¸ Contributing
-We appreciate all contributions to improve xxx. Please refer to [Contributing Guidelines](.github/CONTRIBUTING.md)
+We appreciate all contributions. Please refer to [Contributing Guidelines](.github/CONTRIBUTING.md)
+
+## ⚠ī¸ Warnings
+Some features are under construction:
+ - automatic download works for all datasets and model weights except, respectively, the **Five Billion Pixels** and **BioMassters** datasets and the **GFM** weights.
+
+## 🧮 Some first results

-## Some numbers
+A pre-print is coming soon... Stay tuned!

-| Encoder | Dataset | Epochs | mIoU |
-|---------|--------------|--------|--------|
-| Prithvi | MADOS | 80 | 53.455 |
-| Prithvi | HLSBurnScars | 80 | 86.208 |
-| Prithvi | Sen1Floods11 | 80 | 87.217 |
+| Encoder | Dataset       | Epochs | mIoU   |
+|---------|---------------|--------|--------|
+| Prithvi | MADOS         | 80     | 53.455 |
+| Prithvi | HLSBurnScars  | 80     | 86.208 |
+| Prithvi | Sen1Floods11  | 80     | 87.217 |
+| Prithvi | AI4SmallFarms | 80     | 33.796 |
+
+Please note: #add different conditions

 ## 💡 Acknowledgements
+
+## Šī¸ License
+
+MIT License
+
+Copyright (c) Microsoft Corporation.
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE + +## 📝 Citing + +If you use this software in your work, please cite: + +``` +@misc{pangaea, + author = {}, + title = {Pangaea}, + year = {2024}, + publisher = {GitHub}, + journal = {GitHub repository}, + howpublished = {\url{https://github.com/yurujaja/geofm-bench}}, +} +``` diff --git a/configs/augmentations/pastis_seg.yaml b/configs/augmentations/pastis_seg.yaml new file mode 100644 index 00000000..8b74e954 --- /dev/null +++ b/configs/augmentations/pastis_seg.yaml @@ -0,0 +1,11 @@ +train: + SegPreprocessor: ~ + NormalizeMeanStd: ~ + ResizeToEncoder: ~ + # RandomFlip: + # ud_probability: 0.3 + # lr_probability: 0.3 +test: + SegPreprocessor: ~ + NormalizeMeanStd: ~ + ResizeToEncoder: ~ diff --git a/configs/augmentations/regression_default.yaml b/configs/augmentations/regression_default.yaml index 584be90d..2f8dd0fa 100644 --- a/configs/augmentations/regression_default.yaml +++ b/configs/augmentations/regression_default.yaml @@ -1,13 +1,11 @@ train: RegPreprocessor: ~ - NormalizeMinMax: ~ -# NormalizeMeanStd: ~ + NormalizeMeanStd: ~ RandomCropToEncoder: ~ -# RandomFlip: -# ud_probability: 0.3 -# lr_probability: 0.3 + RandomFlip: + ud_probability: 0.3 + lr_probability: 0.3 test: RegPreprocessor: ~ - NormalizeMinMax: ~ -# NormalizeMeanStd: ~ - Tile: ~ + NormalizeMeanStd: ~ + Tile: ~ \ No newline at end of file diff --git a/configs/datasets/dynamicen.yaml b/configs/datasets/dynamicen.yaml index 190c10c5..46b75340 100644 --- a/configs/datasets/dynamicen.yaml +++ b/configs/datasets/dynamicen.yaml @@ -4,7 +4,7 @@ download_url: None auto_download: False img_size: 1024 -multi_temporal: False +multi_temporal: 6 multi_modal: False # classes diff --git a/configs/datasets/fivebillionpixels.yaml b/configs/datasets/fivebillionpixels.yaml index 315915d8..26336011 100644 --- a/configs/datasets/fivebillionpixels.yaml +++ b/configs/datasets/fivebillionpixels.yaml @@ -2,6 +2,7 @@ dataset_name: FiveBillionPixels root_path: ./data/FiveBillionPixels/cropped download_url: False auto_download: False +use_cmyk: False img_size: 520 multi_temporal: False @@ -37,50 +38,50 @@ classes: - railway station - airport distribution: - - 0.88889 - - 0.11111 - - 0.88889 - - 0.11111 - - 0.88889 - - 0.11111 - - 0.88889 - - 0.11111 - - 0.88889 - - 0.11111 - - 0.88889 - - 0.11111 - - 0.88889 - - 0.11111 - - 0.88889 - - 0.11111 - - 0.11111 - - 0.88889 - - 0.11111 - - 0.88889 - - 0.11111 - - 0.88889 - - 0.11111 - - 0.88889 - - 0.11111 + - 0. + - 0.0368 + - 0.0253 + - 0.3567 + - 0.0752 + - 0.0095 + - 0.0694 + - 0.0096 + - 0.0004 + - 0.0055 + - 0.0025 + - 0.0568 + - 0.0548 + - 0.1396 + - 0.0102 + - 0.0129 + - 0.0004 + - 0.0456 + - 0.0447 + - 0.0003 + - 0.0002 + - 0.0383 + - 0.0025 + - 0.0007 + - 0.0011 bands: optical: + - B8 - B4 - B3 - B2 - - B8 data_mean: optical: + - 92.6 - 124.3 - 94.2 - 98. - - 92.6 data_std: optical: + - 44.5 - 51. - 50. 
- 47.1 - - 44.5 data_min: optical: [0.0000, 0.0000, 0.0000, 0.0000] data_max: diff --git a/configs/datasets/fivebillionpixels_cross_sensors.yaml b/configs/datasets/fivebillionpixels_cross_sensors.yaml new file mode 100644 index 00000000..5f5cd003 --- /dev/null +++ b/configs/datasets/fivebillionpixels_cross_sensors.yaml @@ -0,0 +1,87 @@ +dataset_name: FiveBillionPixels +root_path: /geomatics/gpuserver-1/vmarsocci/FiveBillionPixels/dom_adapt +download_url: False +auto_download: False + +img_size: 1000 +multi_temporal: False +multi_modal: False + +# classes +ignore_index: 0 +num_classes: 25 +classes: +- unlabeled +- industrial area +- paddy field +- irrigated field +- dry cropland +- garden land +- arbor forest +- shrub forest +- park +- natural meadow +- artificial meadow +- river +- urban residential +- lake +- pond +- fish pond +- snow +- bareland +- rural residential +- stadium +- square +- road +- overpass +- railway station +- airport +distribution: + - 0.88889 + - 0.11111 + - 0.88889 + - 0.11111 + - 0.88889 + - 0.11111 + - 0.88889 + - 0.11111 + - 0.88889 + - 0.11111 + - 0.88889 + - 0.11111 + - 0.88889 + - 0.11111 + - 0.88889 + - 0.11111 + - 0.11111 + - 0.88889 + - 0.11111 + - 0.88889 + - 0.11111 + - 0.88889 + - 0.11111 + - 0.88889 + - 0.11111 + +bands: + optical: + - B4 + - B3 + - B2 + - B8 +data_mean: + optical: + - 124.3 + - 94.2 + - 98. + - 92.6 +data_std: + optical: + - 51. + - 50. + - 47.1 + - 44.5 +data_min: + optical: [0.0000, 0.0000, 0.0000, 0.0000] +data_max: + optical: [0.0000, 0.0000, 0.0000, 0.0000] diff --git a/configs/datasets/pastis.yaml b/configs/datasets/pastis.yaml new file mode 100644 index 00000000..5640f252 --- /dev/null +++ b/configs/datasets/pastis.yaml @@ -0,0 +1,110 @@ +dataset_name: Pastis +root_path: ./data/PASTIS-HD +download_url: null +auto_download: False + +img_size: 128 +multi_temporal: 6 +multi_modal: True +limited_label: False + +# classes +ignore_index: 0 +num_classes: 20 +classes: + - Background + - Meadow + - Soft Winter Wheat + - Corn + - Winter Barley + - Winter Rapeseed + - Spring Barley + - Sunflower + - Grapevine + - Beet + - Winter Triticale + - Winter Durum Wheat + - Fruits, Vegetables, Flowers + - Potatoes + - Leguminous Fodder + - Soybeans + - Orchard + - Mixed Cereal + - Sorghum + - Void Label +distribution: + - 0.00000 + - 0.25675 + - 0.06733 + - 0.10767 + - 0.02269 + - 0.01451 + - 0.00745 + - 0.01111 + - 0.08730 + - 0.00715 + - 0.00991 + - 0.01398 + - 0.02149 + - 0.00452 + - 0.02604 + - 0.00994 + - 0.02460 + - 0.00696 + - 0.00580 + - 0.29476 + +bands: + optical: + - B2 + - B3 + - B4 + - B5 + - B6 + - B7 + - B8 + - B8A + - B11 + - B12 + sar: + - VV + - VH + - VV-VH +data_mean: + optical: + - 1161.6764 + - 1371.4307 + - 1423.4067 + - 1759.7251 + - 2714.5259 + - 3055.8376 + - 3197.8960 + - 3313.3577 + - 2415.9675 + - 1626.8431 + sar: + - -10.9433 + - -17.3600 + - 6.4167 +data_std: + optical: + - 2045.0698 + - 1983.1763 + - 2060.7969 + - 1968.8173 + - 1867.2159 + - 1885.1361 + - 1897.5105 + - 1885.1636 + - 1542.7665 + - 1375.2511 + sar: + - 3.3847 + - 3.3727 + - 3.3874 +data_min: + optical: [-10000., -10000., -10000., -10000., -10000., -10000., -10000., -10000., -10000., -10000.] + sar: [-39.5312, -43.1250, -20.6562] +data_max: + optical: [22256., 21891., 22626., 21814., 20134., 19282., 18957., 18482., 16935., 14668.] 
+ sar: [34.8750, 27.3125, 51.9062] diff --git a/configs/datasets/pastis_si.yaml b/configs/datasets/pastis_si.yaml new file mode 100644 index 00000000..43d990c5 --- /dev/null +++ b/configs/datasets/pastis_si.yaml @@ -0,0 +1,110 @@ +dataset_name: Pastis +root_path: ./data/PASTIS-HD +download_url: null +auto_download: False + +img_size: 128 +multi_temporal: 1 +multi_modal: True +limited_label: False + +# classes +ignore_index: 0 +num_classes: 20 +classes: + - Background + - Meadow + - Soft Winter Wheat + - Corn + - Winter Barley + - Winter Rapeseed + - Spring Barley + - Sunflower + - Grapevine + - Beet + - Winter Triticale + - Winter Durum Wheat + - Fruits, Vegetables, Flowers + - Potatoes + - Leguminous Fodder + - Soybeans + - Orchard + - Mixed Cereal + - Sorghum + - Void Label +distribution: + - 0.00000 + - 0.25675 + - 0.06733 + - 0.10767 + - 0.02269 + - 0.01451 + - 0.00745 + - 0.01111 + - 0.08730 + - 0.00715 + - 0.00991 + - 0.01398 + - 0.02149 + - 0.00452 + - 0.02604 + - 0.00994 + - 0.02460 + - 0.00696 + - 0.00580 + - 0.29476 + +bands: + optical: + - B2 + - B3 + - B4 + - B5 + - B6 + - B7 + - B8 + - B8A + - B11 + - B12 + sar: + - VV + - VH + - VV-VH +data_mean: + optical: + - 1161.6764 + - 1371.4307 + - 1423.4067 + - 1759.7251 + - 2714.5259 + - 3055.8376 + - 3197.8960 + - 3313.3577 + - 2415.9675 + - 1626.8431 + sar: + - -10.9433 + - -17.3600 + - 6.4167 +data_std: + optical: + - 2045.0698 + - 1983.1763 + - 2060.7969 + - 1968.8173 + - 1867.2159 + - 1885.1361 + - 1897.5105 + - 1885.1636 + - 1542.7665 + - 1375.2511 + sar: + - 3.3847 + - 3.3727 + - 3.3874 +data_min: + optical: [-10000., -10000., -10000., -10000., -10000., -10000., -10000., -10000., -10000., -10000.] + sar: [-39.5312, -43.1250, -20.6562] +data_max: + optical: [22256., 21891., 22626., 21814., 20134., 19282., 18957., 18482., 16935., 14668.] 
+ sar: [34.8750, 27.3125, 51.9062] diff --git a/datasets/__init__.py b/datasets/__init__.py index 026a3837..7d85797c 100644 --- a/datasets/__init__.py +++ b/datasets/__init__.py @@ -1,10 +1,11 @@ -from .mados import MADOS -from .hlsburnscars import HLSBurnScars -from .sen1floods11 import Sen1Floods11 -from .xview2 import xView2 +from .ai4smallfarms import AI4SmallFarms from .biomassters import BioMassters from .croptypemapping import CropTypeMappingSouthSudan -from .ai4smallfarms import AI4SmallFarms -from .spacenet7 import SN7MAPPING, SN7CD from .fivebillionpixels import FiveBillionPixels +from .hlsburnscars import HLSBurnScars +from .mados import MADOS +from .pastis import Pastis +from .sen1floods11 import Sen1Floods11 +from .spacenet7 import SN7CD, SN7MAPPING from .utae_dynamicen import DynamicEarthNet +from .xview2 import xView2 diff --git a/datasets/fivebillionpixels.py b/datasets/fivebillionpixels.py index c8e53991..dedd6ef1 100644 --- a/datasets/fivebillionpixels.py +++ b/datasets/fivebillionpixels.py @@ -34,90 +34,33 @@ def __init__(self, cfg, split, is_train = True): """ super().__init__() self._base_dir = cfg['root_path'] - # print(os.path.join(self._base_dir, split, 'imgs', '*.tif')) - # print(os.path.join(self._base_dir, split, 'labels', '*.tif')) - # print(self._image_dir) - # print(self._label_dir) - # _splits_dir = os.path.join(self._base_dir, 'list') - # self.split = [split] - - # self.args = args - - # self.im_ids = [] - # self.images = [] - # self.labels = [] - - # for splt in self.split: - # with open(os.path.join(os.path.join(_splits_dir, splt + '.txt')), "r") as f: - # lines = f.read().splitlines() - - # if splt == 'train': - # lines = random.sample(lines, len(os.listdir(os.path.join(args.target_dir, args.target)))) - # elif split == 'val': - # lines = random.sample(lines, 500) - # self.root_path = cfg['root_path'] self.data_mean = cfg['data_mean'] self.data_std = cfg['data_std'] self.classes = cfg['classes'] + self.use_cmyk = cfg['use_cmyk'] self.class_num = len(self.classes) self.split = split self.is_train = is_train self._image_dir = sorted(glob(os.path.join(self._base_dir, self.split, 'imgs', '*.tif'))) self._label_dir = sorted(glob(os.path.join(self._base_dir, self.split, 'labels', '*.tif'))) - # print(split) - # print(os.path.join(self._base_dir, self.split, 'imgs', '*.tif')) - # print(os.path.join(self._base_dir, self.split, 'labels', '*.png')) - # print(self._image_dir) - # print((self._label_dir)) - # print(len(self._image_dir)) - # print(len(self._label_dir)) - - # self.split_mapping = {'train': 'training', 'val': 'validation', 'test': 'validation'} - - # self.image_list = sorted(glob(os.path.join(self.root_path, self.split_mapping[self.split], '*merged.tif'))) - # self.target_list = sorted(glob(os.path.join(self.root_path, self.split_mapping[self.split], '*mask.tif'))) - - - # for ii, line in enumerate(lines): - # _image = os.path.join(self._image_dir, line + ".tif") - # _label = os.path.join(self._label_dir, line + ".png") - # assert os.path.isfile(_image) - # assert os.path.isfile(_label) - # self.im_ids.append(line) - # self.images.append(_image) - # self.labels.append(_label) - - # assert (len(self.images) == len(self.labels)) - - # Display stats - # print('Number of images in {}: {:d}'.format(split, len(self.images))) def __len__(self): return len(self._image_dir) def __getitem__(self, index): - # _img, _target = self._make_img_gt_point_pair(index) - # print(index) - # image = Image.open(self._image_dir[index]).convert('CMYK') #check it also on the 
normalization - # target = Image.open(self._label_dir[index]) - - image = tiff.imread(self._image_dir[index])#.convert('CMYK') #check it also on the normalization - target = tiff.imread(self._label_dir[index]) #, cv2.IMREAD_UNCHANGED) - # image = TF.pil_to_tensor(image) - # target = TF.pil_to_tensor(target).squeeze(0).to(torch.int64) - - image = image.astype(np.float32) # Convert to float32 + if self.use_cmyk: + image = Image.open(self._image_dir[index]).convert('CMYK') + image = TF.pil_to_tensor(image) + else: + image = tiff.imread(self._image_dir[index])#.convert('CMYK') #check it also on the normalization + image = image.astype(np.float32) # Convert to float32 + image = torch.from_numpy(image).permute(2, 0, 1) + + target = tiff.imread(self._label_dir[index]) target = target.astype(np.int64) # Convert to int64 (since it's a mask) - - # Tile the image and target to the fixed size specified in the config - # image, target = self.tile_image_and_mask(image, target, self.img_size) - - image = torch.from_numpy(image).permute(2, 0, 1) target = torch.from_numpy(target).long() - # print(image.shape) - # print(target.shape) output = { 'image': { @@ -169,7 +112,7 @@ def __getitem__(self, index): def get_splits(dataset_config): dataset_train = FiveBillionPixels(dataset_config, split="train", is_train=True) dataset_val = FiveBillionPixels(dataset_config, split="val", is_train=False) - dataset_test = dataset_val + dataset_test = FiveBillionPixels(dataset_config, split="test", is_train=False) return dataset_train, dataset_val, dataset_test @staticmethod diff --git a/datasets/mados.py b/datasets/mados.py index 9e917217..62a4826d 100644 --- a/datasets/mados.py +++ b/datasets/mados.py @@ -23,7 +23,6 @@ from .utils import DownloadProgressBar from utils.registry import DATASET_REGISTRY -import matplotlib.pyplot as plt ############################################################### @@ -156,4 +155,4 @@ def get_splits(dataset_config): dataset_train = MADOS(cfg=dataset_config, split="train", is_train=True) dataset_val = MADOS(cfg=dataset_config, split="val", is_train=False) dataset_test = MADOS(cfg=dataset_config, split="test", is_train=False) - return dataset_train, dataset_val, dataset_test + return dataset_train, dataset_val, dataset_test \ No newline at end of file diff --git a/datasets/pastis.py b/datasets/pastis.py new file mode 100644 index 00000000..5313a951 --- /dev/null +++ b/datasets/pastis.py @@ -0,0 +1,387 @@ +### +# Modified version of the PASTIS-HD dataset +# original code https://github.com/gastruc/OmniSat/blob/main/src/data/Pastis.py +### + +import json +import os +from datetime import datetime + +import geopandas as gpd +import numpy as np +import pandas as pd +import rasterio +import torch +from einops import rearrange +from omegaconf import OmegaConf +from torch.utils.data import Dataset + +from utils.registry import DATASET_REGISTRY + + +def prepare_dates(date_dict, reference_date): + """Date formating.""" + if type(date_dict) == str: + date_dict = json.loads(date_dict) + d = pd.DataFrame().from_dict(date_dict, orient="index") + d = d[0].apply( + lambda x: ( + datetime(int(str(x)[:4]), int(str(x)[4:6]), int(str(x)[6:])) + - reference_date + ).days + ) + return torch.tensor(d.values) + + +def split_image(image_tensor, nb_split, id): + """ + Split the input image tensor into four quadrants based on the integer i. + To use if Pastis data does not fit in your GPU memory. 
+ Returns the corresponding quadrant based on the value of i + """ + if nb_split == 1: + return image_tensor + i1 = id // nb_split + i2 = id % nb_split + height, width = image_tensor.shape[-2:] + half_height = height // nb_split + half_width = width // nb_split + if image_tensor.dim() == 4: + return image_tensor[ + :, + :, + i1 * half_height : (i1 + 1) * half_height, + i2 * half_width : (i2 + 1) * half_width, + ].float() + if image_tensor.dim() == 3: + return image_tensor[ + :, + i1 * half_height : (i1 + 1) * half_height, + i2 * half_width : (i2 + 1) * half_width, + ].float() + if image_tensor.dim() == 2: + return image_tensor[ + i1 * half_height : (i1 + 1) * half_height, + i2 * half_width : (i2 + 1) * half_width, + ].float() + + +@DATASET_REGISTRY.register() +class Pastis(Dataset): + def __init__( + self, + cfg: OmegaConf, + split: str, + is_train: bool = True, + ): + """ + Initializes the dataset. + Args: + path (str): path to the dataset + modalities (list): list of modalities to use + folds (list): list of folds to use + reference_date (str date): reference date for the data + nb_split (int): number of splits from one observation + num_classes (int): number of classes + """ + super(Pastis, self).__init__() + + if split == "train": + folds = [1, 2, 3] + elif split == "val": + folds = [4] + elif split == "test": + folds = [5] + + self.split = split + self.path = cfg["root_path"] + self.data_mean = cfg["data_mean"] + self.data_std = cfg["data_std"] + self.data_min = cfg["data_min"] + self.data_max = cfg["data_max"] + self.classes = cfg["classes"] + self.class_num = len(self.classes) + self.grid_size = cfg["multi_temporal"] + self.modalities = ["s2", "aerial", "s1-asc"] + self.nb_split = 1 + + reference_date = "2018-09-01" + self.reference_date = datetime(*map(int, reference_date.split("-"))) + + self.meta_patch = gpd.read_file(os.path.join(self.path, "metadata.geojson")) + + self.num_classes = 20 + + if folds is not None: + self.meta_patch = pd.concat( + [self.meta_patch[self.meta_patch["Fold"] == f] for f in folds] + ) + + def __getitem__(self, i): + """ + Returns an item from the dataset. 
+ Args: + i (int): index of the item + Returns: + dict: dictionary with keys "label", "name" and the other corresponding to the modalities used + """ + line = self.meta_patch.iloc[i // (self.nb_split * self.nb_split)] + name = line["ID_PATCH"] + part = i % (self.nb_split * self.nb_split) + label = torch.from_numpy( + np.load( + os.path.join(self.path, "ANNOTATIONS/TARGET_" + str(name) + ".npy") + )[0].astype(np.int32) + ) + # label = torch.unique(split_image(label, self.nb_split, part)).long() + # label = torch.sum( + # torch.nn.functional.one_hot(label, num_classes=self.num_classes), dim=0 + # ) + # label = label[1:-1] # remove Background and Void classes + output = {"label": label, "name": name} + + for modality in self.modalities: + if modality == "aerial": + with rasterio.open( + os.path.join( + self.path, + "DATA_SPOT/PASTIS_SPOT6_RVB_1M00_2019/SPOT6_RVB_1M00_2019_" + + str(name) + + ".tif", + ) + ) as f: + output["aerial"] = split_image( + torch.FloatTensor(f.read()), self.nb_split, part + ) + elif modality == "s1-median": + modality_name = "s1a" + images = split_image( + torch.from_numpy( + np.load( + os.path.join( + self.path, + "DATA_{}".format(modality_name.upper()), + "{}_{}.npy".format(modality_name.upper(), name), + ) + ) + ), + self.nb_split, + part, + ).to(torch.float32) + out, _ = torch.median(images, dim=0) + output[modality] = out + elif modality == "s2-median": + modality_name = "s2" + images = split_image( + torch.from_numpy( + np.load( + os.path.join( + self.path, + "DATA_{}".format(modality_name.upper()), + "{}_{}.npy".format(modality_name.upper(), name), + ) + ) + ), + self.nb_split, + part, + ).to(torch.float32) + out, _ = torch.median(images, dim=0) + output[modality] = out + elif modality == "s1-4season-median": + modality_name = "s1a" + images = split_image( + torch.from_numpy( + np.load( + os.path.join( + self.path, + "DATA_{}".format(modality_name.upper()), + "{}_{}.npy".format(modality_name.upper(), name), + ) + ) + ), + self.nb_split, + part, + ).to(torch.float32) + dates = prepare_dates( + line["-".join(["dates", modality_name.upper()])], + self.reference_date, + ) + l = [] + for i in range(4): + mask = (dates >= 92 * i) & (dates < 92 * (i + 1)) + if sum(mask) > 0: + r, _ = torch.median(images[mask], dim=0) + l.append(r) + else: + l.append( + torch.zeros( + (images.shape[1], images.shape[-2], images.shape[-1]) + ) + ) + output[modality] = torch.cat(l) + elif modality == "s2-4season-median": + modality_name = "s2" + images = split_image( + torch.from_numpy( + np.load( + os.path.join( + self.path, + "DATA_{}".format(modality_name.upper()), + "{}_{}.npy".format(modality_name.upper(), name), + ) + ) + ), + self.nb_split, + part, + ).to(torch.float32) + dates = prepare_dates( + line["-".join(["dates", modality_name.upper()])], + self.reference_date, + ) + l = [] + for i in range(4): + mask = (dates >= 92 * i) & (dates < 92 * (i + 1)) + if sum(mask) > 0: + r, _ = torch.median(images[mask], dim=0) + l.append(r) + else: + l.append( + torch.zeros( + (images.shape[1], images.shape[-2], images.shape[-1]) + ) + ) + output[modality] = torch.cat(l) + else: + if len(modality) > 3: + modality_name = modality[:2] + modality[3] + output[modality] = split_image( + torch.from_numpy( + np.load( + os.path.join( + self.path, + "DATA_{}".format(modality_name.upper()), + "{}_{}.npy".format(modality_name.upper(), name), + ) + ) + ), + self.nb_split, + part, + ) + output["_".join([modality, "dates"])] = prepare_dates( + line["-".join(["dates", modality_name.upper()])], + 
self.reference_date, + ) + else: + output[modality] = split_image( + torch.from_numpy( + np.load( + os.path.join( + self.path, + "DATA_{}".format(modality.upper()), + "{}_{}.npy".format(modality.upper(), name), + ) + ) + ), + self.nb_split, + part, + ) + output["_".join([modality, "dates"])] = prepare_dates( + line["-".join(["dates", modality.upper()])], self.reference_date + ) + N = len(output[modality]) + if N > 50: + random_indices = torch.randperm(N)[:50] + output[modality] = output[modality][random_indices] + output["_".join([modality, "dates"])] = output[ + "_".join([modality, "dates"]) + ][random_indices] + + optical_ts = rearrange(output["s2"], "t c h w -> c t h w") + sar_ts = rearrange(output["s1-asc"], "t c h w -> c t h w") + + if self.grid_size == 1: + # we only take the last frame + optical_ts = optical_ts[:, -1] + sar_ts = sar_ts[:, -1] + else: + # select evenly spaced samples + optical_indexes = torch.linspace( + 0, optical_ts.shape[1] - 1, self.grid_size, dtype=torch.long + ) + sar_indexes = torch.linspace( + 0, sar_ts.shape[1] - 1, self.grid_size, dtype=torch.long + ) + + optical_ts = optical_ts[:, optical_indexes] + sar_ts = sar_ts[:, sar_indexes] + + return { + "image": { + "optical": optical_ts.to(torch.float32), + "sar": sar_ts.to(torch.float32), + }, + "target": output["label"], + "metadata": {}, + } + + def __len__(self) -> int: + return len(self.meta_patch) * self.nb_split * self.nb_split + + @staticmethod + def get_splits(dataset_config): + dataset_train = Pastis(cfg=dataset_config, split="train", is_train=True) + dataset_val = Pastis(cfg=dataset_config, split="val", is_train=False) + dataset_test = Pastis(cfg=dataset_config, split="test", is_train=False) + return dataset_train, dataset_val, dataset_test + + @staticmethod + def download(dataset_config: dict, silent=False): + pass + + +if __name__ == "__main__": + class_prob = { + "Background": 0.0, + "Meadow": 31292, + "Soft Winter Wheat": 8206, + "Corn": 13123, + "Winter Barley": 2766, + "Winter Rapeseed": 1769, + "Spring Barley": 908, + "Sunflower": 1355, + "Grapevine": 10640, + "Beet": 871, + "Winter Triticale": 1208, + "Winter Durum Wheat": 1704, + "Fruits, Vegetables, Flowers": 2619, + "Potatoes": 551, + "Leguminous Fodder": 3174, + "Soybeans": 1212, + "Orchard": 2998, + "Mixed Cereal": 848, + "Sorghum": 707, + "Void Label": 35924, + } + + # get the class weights + class_weights = np.array([class_prob[key] for key in class_prob.keys()]) + class_weights = class_weights / class_weights.sum() + print("Class weights: ") + for i, key in enumerate(class_prob.keys()): + print(key, "->", class_weights[i]) + print("_" * 100) + + cfg = { + "root_path": "/share/DEEPLEARNING/datasets/PASTIS-HD", + "data_mean": None, + "data_std": None, + "classes": { + "0": "Background", + "1": "Meadow", + }, + "data_min": 0, + "data_max": 1, + } + + dataset = Pastis(cfg, "train", is_train=True) + train_dataset, val_dataset, test_dataset = Pastis.get_splits(cfg) diff --git a/datasets/utae_dynamicen.py b/datasets/utae_dynamicen.py index 3747a643..ca478da4 100644 --- a/datasets/utae_dynamicen.py +++ b/datasets/utae_dynamicen.py @@ -36,7 +36,7 @@ def __init__(self, cfg, split, is_train=True): self.split = split self.is_train = is_train - self.mode = 'single' + self.mode = 'weekly' self.files = [] diff --git a/engine/data_preprocessor.py b/engine/data_preprocessor.py index 7eacfc8d..c9d4633b 100644 --- a/engine/data_preprocessor.py +++ b/engine/data_preprocessor.py @@ -1,18 +1,16 @@ import random - import math import torch import 
torch.nn.functional as F import torchvision.transforms as T +import torchvision.transforms.functional as TF from typing import Callable import numpy as np import logging - import omegaconf - from utils.registry import AUGMENTER_REGISTRY @@ -493,18 +491,14 @@ def __getitem__(self, index): data = self.dataset[index] for k, v in data["image"].items(): if k not in self.ignore_modalities and k in self.encoder_cfg.input_bands: - data["image"][k] = T.Resize(self.size)(v) + data["image"][k] = T.resize(v, self.size, interpolation=T.InterpolationMode.BILINEAR, antialias=True) if data["target"].ndim == 2: data["target"] = data["target"].unsqueeze(0) - data["target"] = T.Resize( - self.size, interpolation=T.InterpolationMode.NEAREST - )(data["target"]) + data["target"] = T.resize(data["target"], self.size, interpolation=T.InterpolationMode.NEAREST) data["target"] = data["target"].squeeze(0) else: - data["target"] = T.Resize( - self.size, interpolation=T.InterpolationMode.NEAREST - )(data["target"]) + data["target"] = T.resize(data["target"], self.size, interpolation=T.InterpolationMode.NEAREST) return data diff --git a/engine/trainer.py b/engine/trainer.py index 3da2754e..4c2df54e 100644 --- a/engine/trainer.py +++ b/engine/trainer.py @@ -25,6 +25,7 @@ def __init__(self, args, model, train_loader, criterion, optimizer, lr_scheduler self.optimizer = optimizer self.lr_scheduler = lr_scheduler self.evaluator = evaluator + self.ignore_index = args["dataset"]["ignore_index"] self.logger = logging.getLogger() self.training_stats = {name: RunningAverageMeter(length=self.batch_per_epoch) for name in ['loss', 'data_time', 'batch_time', 'eval_time']} self.training_metrics = {} @@ -91,22 +92,22 @@ def train_one_epoch(self, epoch): with torch.cuda.amp.autocast(enabled=self.enable_mixed_precision, dtype=self.precision): logits = self.model(image, output_shape=target.shape[-2:]) loss = self.compute_loss(logits, target) - self.compute_logging_metrics(logits.detach().clone(), target.detach().clone()) self.optimizer.zero_grad() - self.scaler.scale(loss).backward() - self.scaler.step(self.optimizer) - self.scaler.update() + if not torch.isnan(loss): + self.scaler.scale(loss).backward() + self.scaler.step(self.optimizer) + self.scaler.update() + self.training_stats['loss'].update(loss.item()) + with torch.no_grad(): + self.compute_logging_metrics(logits, target) + if (batch_idx + 1) % self.args.log_interval == 0: + self.log(batch_idx + 1, epoch) + else: + self.logger.warning("Skip batch {} because of nan loss".format(batch_idx + 1)) self.lr_scheduler.step() - self.training_stats['loss'].update(loss.item()) - if (batch_idx + 1) % self.args.log_interval == 0: - self.log(batch_idx + 1, epoch) - self.training_stats['batch_time'].update(time.time() - end_time) - #print(self.training_stats['batch_time'].val, self.training_stats['batch_time'].avg) - end_time = time.time() - if self.use_wandb and self.rank == 0: self.wandb.log( { @@ -121,6 +122,9 @@ def train_one_epoch(self, epoch): step=epoch * len(self.train_loader) + batch_idx, ) + self.training_stats['batch_time'].update(time.time() - end_time) + end_time = time.time() + def get_checkpoint(self, epoch): checkpoint = { "model": self.model.module.state_dict(), @@ -234,7 +238,7 @@ def compute_logging_metrics(self, logits, target): else: pred = torch.argmax(logits, dim=1, keepdim=True) target = target.unsqueeze(1) - ignore_mask = target == -1 + ignore_mask = target == self.ignore_index target[ignore_mask] = 0 ignore_mask = ignore_mask.expand(-1, num_classes if num_classes > 1 
else 2, -1, -1) @@ -276,35 +280,8 @@ def compute_loss(self, logits, target): @torch.no_grad() def compute_logging_metrics(self, logits, target): - # logits = F.interpolate(logits, size=target.shape[1:], mode='bilinear') - # print(logits.shape) - # print(target.shape) - mse = F.mse_loss(logits.squeeze(dim=1), target) - - # pred = torch.argmax(logits, dim=1, keepdim=True) - # target = target.unsqueeze(1) - # ignore_mask = target == -1 - # target[ignore_mask] = 0 - # ignore_mask = ignore_mask.expand(-1, logits.shape[1], -1, -1) - - # binary_pred = torch.zeros(logits.shape, dtype=bool, device=self.device) - # binary_target = torch.zeros(logits.shape, dtype=bool, device=self.device) - # binary_pred.scatter_(dim=1, index=pred, src=torch.ones_like(binary_pred)) - # binary_target.scatter_(dim=1, index=target, src=torch.ones_like(binary_target)) - # binary_pred[ignore_mask] = 0 - # binary_target[ignore_mask] = 0 - - # intersection = torch.logical_and(binary_pred, binary_target) - # union = torch.logical_or(binary_pred, binary_target) - - # acc = intersection.sum() / binary_target.sum() * 100 - # macc = torch.nanmean(intersection.sum(dim=(0, 2, 3)) / binary_target.sum(dim=(0, 2, 3))) * 100 - # miou = torch.nanmean(intersection.sum(dim=(0, 2, 3)) / union.sum(dim=(0, 2, 3))) * 100 - self.training_metrics['MSE'].update(mse.item()) - # self.training_metrics['mAcc'].update(macc.item()) - # self.training_metrics['mIoU'].update(miou.item()) diff --git a/environment.yaml b/environment.yaml index 333eca5a..1c486cf8 100644 --- a/environment.yaml +++ b/environment.yaml @@ -28,4 +28,4 @@ dependencies: - google-cloud-storage - omegaconf - pydataverse - - pytest + - pytest \ No newline at end of file diff --git a/foundation_models/prithvi_encoder.py b/foundation_models/prithvi_encoder.py index 6f94ca38..79508983 100644 --- a/foundation_models/prithvi_encoder.py +++ b/foundation_models/prithvi_encoder.py @@ -108,11 +108,7 @@ def forward(self, image): for i, blk in enumerate(self.blocks): x = blk(x) if i in self.output_layers: - #out = self.norm(x) if i == 11 else x - # print(x.shape) out = x[:, 1:, :].permute(0, 2, 1).view(x.shape[0], -1, self.num_frames, self.img_size // self.patch_size, self.img_size // self.patch_size).squeeze(2).contiguous() - # out = x[:, 1:, :].permute(0, 2, 1).reshape(x.shape[0], -1, self.img_size // self.patch_size, self.img_size // self.patch_size).contiguous() - output.append(out) return output diff --git a/foundation_models/spectralgpt_encoder.py b/foundation_models/spectralgpt_encoder.py index 50f3aec1..220bd625 100644 --- a/foundation_models/spectralgpt_encoder.py +++ b/foundation_models/spectralgpt_encoder.py @@ -35,7 +35,7 @@ def __init__(self, self.model_name = "SpectralGPT" self.output_layers = cfg['output_layers'] - self.num_frames = cfg['multi_temporal'] if cfg['multi_temporal'] else 1 + self.num_frames = 1 self.patch_embed = PatchEmbed( img_size, patch_size, self.num_frames, embed_dim, in_chans, t_patch_size) diff --git a/foundation_models/unet_encoder.py b/foundation_models/unet_encoder.py index a220e182..1430c68f 100644 --- a/foundation_models/unet_encoder.py +++ b/foundation_models/unet_encoder.py @@ -21,7 +21,7 @@ def __init__(self, cfg, in_channels: int, topology: Sequence[int]): self.encoder = Encoder(self.topology) def forward(self, image): - x = image['optical'].squeeze() + x = image['optical'] feat = self.in_conv(x) output = self.encoder(feat) return output diff --git a/run.py b/run.py index 21c11cb7..4f20c1da 100644 --- a/run.py +++ b/run.py @@ -149,6 +149,16 @@ def 
main(): exp_dir.mkdir(parents=True, exist_ok=True) logger_path = exp_dir / "train.log" + if cfg.use_wandb and cfg.rank == 0: + import wandb + # initialize new wandb run + wandb.init( + project="geofm-bench", + name=exp_name, + config=OmegaConf.to_container(cfg, resolve=True), + ) + cfg['wandb_run_id'] = wandb.run.id + config_log_dir = exp_dir / "configs" config_log_dir.mkdir(exist_ok=True) OmegaConf.save(cfg, config_log_dir / "config.yaml") @@ -157,25 +167,22 @@ def main(): exp_name = exp_dir.name logger_path = exp_dir / "test.log" + if cfg.use_wandb and cfg.rank == 0: + import wandb + # resume wandb run + wandb.init( + project="geofm-bench", + name=exp_name, + id=cfg.get('wandb_run_id'), + resume='allow', + ) + logger = init_logger(logger_path, rank=cfg.rank) logger.info("============ Initialized logger ============") logger.info(pprint.pformat(OmegaConf.to_container(cfg), compact=True).strip("{}")) logger.info("The experiment is stored in %s\n" % exp_dir) logger.info(f"Device used: {device}") - # init wandb - if cfg.use_wandb and cfg.rank == 0: - import wandb - - wandb.init( - project="geofm-bench", - name=exp_name, - config=OmegaConf.to_container(cfg, resolve=True), - resume='allow', - id=cfg.get('wandb_run_id'), - ) - cfg['wandb_run_id'] = wandb.run.id - # get datasets dataset = DATASET_REGISTRY.get(cfg.dataset.dataset_name) dataset.download(cfg.dataset, silent=False) @@ -241,8 +248,6 @@ def main(): collate_fn = get_collate_fn(cfg) # training if not cfg.eval_dir: - - if 0 < cfg.limited_label < 1: indices = random.sample(range(len(train_dataset)), int(len(train_dataset)*cfg.limited_label)) train_dataset = Subset(train_dataset, indices) @@ -350,30 +355,30 @@ def main(): trainer.train() # Evaluation - else: - test_loader = DataLoader( - test_dataset, - sampler=DistributedSampler(test_dataset), - batch_size=cfg.batch_size, - num_workers=cfg.num_workers, - pin_memory=True, - persistent_workers=False, - drop_last=False, - collate_fn=collate_fn, - ) + test_loader = DataLoader( + test_dataset, + sampler=DistributedSampler(test_dataset), + batch_size=cfg.batch_size, + num_workers=cfg.num_workers, + pin_memory=True, + persistent_workers=False, + drop_last=False, + collate_fn=collate_fn, + ) - logger.info("Built {} dataset for evaluation.".format(dataset_name)) + logger.info("Built {} dataset for evaluation.".format(dataset_name)) - if task_name == "regression": - # TODO: This doesn't work atm - test_evaluator = RegEvaluator(cfg, test_loader, exp_dir, device) - else: - test_evaluator = SegEvaluator(cfg, test_loader, exp_dir, device) + if task_name == "regression": + # TODO: This doesn't work atm + test_evaluator = RegEvaluator(cfg, test_loader, exp_dir, device) + else: + test_evaluator = SegEvaluator(cfg, test_loader, exp_dir, device) + + model_ckpt_path = os.path.join( + exp_dir, next(f for f in os.listdir(exp_dir) if f.endswith("_best.pth")) + ) + test_evaluator.evaluate(model, "best model", model_ckpt_path) - model_ckpt_path = os.path.join( - exp_dir, next(f for f in os.listdir(exp_dir) if f.endswith("_best.pth")) - ) - test_evaluator.evaluate(model, "best model", model_ckpt_path) if cfg.use_wandb and cfg.rank == 0: wandb.finish() diff --git a/utils/compute_data_statistics.py b/utils/compute_data_statistics.py index d99c439c..887130a7 100644 --- a/utils/compute_data_statistics.py +++ b/utils/compute_data_statistics.py @@ -1,10 +1,38 @@ +from utils import registry import utils.registry import omegaconf import numpy as np import tqdm -import datasets import torch -import pprint + + +class 
RunningStats: + def __init__(self, stats_dim): + self.n = 0 + self.sum = torch.zeros(stats_dim) + self.sum_2 = torch.zeros(stats_dim) + + self.min = 10e10 * torch.ones(stats_dim) + self.max = -10e10 * torch.ones(stats_dim) + + def update(self, x, reduce_dim): + self.n += np.prod([x.shape[i] for i in reduce_dim]) + self.sum += torch.sum(x, reduce_dim) + self.sum_2 += torch.sum(x**2, reduce_dim) + + x_min = torch.amin(x, reduce_dim) + x_max = torch.amax(x, reduce_dim) + self.min = torch.min(self.min, x_min) + self.max = torch.max(self.max, x_max) + + def finalize(self): + return { + "mean": self.sum / self.n, + "std": torch.sqrt(self.sum_2 / self.n - (self.sum / self.n) ** 2), + "min": self.min, + "max": self.max, + } + configs = [ "configs/datasets/mados.yaml", @@ -17,28 +45,23 @@ dataset = utils.registry.DATASET_REGISTRY.get(cfg.dataset_name) dataset.download(cfg, silent=False) train_dataset, val_dataset, test_dataset = dataset.get_splits(cfg) + stats = {} + data = train_dataset.__getitem__(0) - min = {} - max = {} + # STATS initialization + stats = {} + for modality, img in data["image"].items(): + n_channels = img.shape[0] + stats[modality] = RunningStats(n_channels) + # STATS computation for data in tqdm.tqdm(train_dataset, desc=cfg.dataset_name): - for modality, img in data['image'].items(): - dims = [i for i in range(len(img.shape))] - dims.pop(-3) - img = torch.nan_to_num(img) - local_max = torch.amax(img, dim=dims) - local_min = torch.amin(img, dim=dims) - - if min.get(modality, None) is None: - print(modality, local_min.shape) - min[modality] = torch.full_like(local_min, 10e10) - max[modality] = torch.full_like(local_max, -10e10) - - min[modality] = torch.minimum(min[modality], local_min) - max[modality] = torch.maximum(max[modality], local_max) - - pprint.pp(cfg.dataset_name) - pprint.pp({ - "max": max, - "min": min - }) + for modality, img in data["image"].items(): + reduce_dim = list(range(1, img.ndim)) + stats[modality].update(img, reduce_dim) + + # STATS finalization + for modality, stat in stats.items(): + print(modality) + print(stat.finalize()) + print("_" * 100) diff --git a/utils/configs.py b/utils/configs.py index 2b0f8ebf..468ae597 100644 --- a/utils/configs.py +++ b/utils/configs.py @@ -11,10 +11,23 @@ def load_configs(parser:argparse.ArgumentParser) -> OmegaConf: cli_provided, cli_defaults = omegaconf_from_argparse(parser) all_cli = OmegaConf.merge(cli_defaults, cli_provided) + # print(all_cli) + if all_cli.eval_dir: # Just load the dumped config file if we are evaluating eval_config_path = pathlib.Path(all_cli.eval_dir) / 'configs' file_cfg = OmegaConf.load(eval_config_path/"config.yaml") + + # if all_cli.config: + # to do the test on different datasets (with same number of classes) + bootstrap_cfg = OmegaConf.merge(cli_defaults, file_cfg, cli_provided) + if bootstrap_cfg.augmentation_config_path is not None: + augmentation_cfg = OmegaConf.load(bootstrap_cfg.augmentation_config_path) + file_cfg["augmentation"] = augmentation_cfg + if bootstrap_cfg.augmentation_config_path is not None: + dataset_cfg = OmegaConf.load(bootstrap_cfg.dataset_config_path) + file_cfg["dataset"] = dataset_cfg + cfg = OmegaConf.merge(cli_defaults, file_cfg, cli_provided) elif all_cli.config: