Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update satlasnet info #64

Merged
merged 3 commits into from
Sep 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
File renamed without changes
39 changes: 27 additions & 12 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ While geospatial foundation models (GFMs) have proliferated rapidly, their evalu
To bridge this gap, we propose a standardized evaluation protocol that incorporates a wide-ranging selection of datasets, tasks, resolutions, and sensor types, establishing a robust and widely applicable benchmark for GFMs.


<img src="figures/geofmbenchmark.png" alt="PANGAEA: a diverse benchmark for geospatial foundation models" width="80%">
<img src=".github/geofmbenchmark.png" alt="PANGAEA: a diverse benchmark for geospatial foundation models" width="90%">


In this repo, you can find the code to benchmark GFMs. For the moment we included several GFMs that present different approach. We look forward to adding new models and datasets.
Expand Down Expand Up @@ -152,22 +152,35 @@ torchrun --nnodes=1 --nproc_per_node=1 pangaea/run.py \
```

#### Multi-Temporal Semantic Segmentation
- Multi-temporal decoder config (e.g. `configs/decoder/seg_upernet_mt_ltae.yaml` if you want to use `ltae` as a strategy to combine multi-temporal info) should be used.
- In addition, in the dataset config, indicate the number of time frames, e.g., `multi_temporal: 6`

Multi-temporal decoder config (e.g. `configs/decoder/seg_upernet_mt_ltae.yaml` if you want to use `ltae` as a strategy to combine multi-temporal info) should be used. e.g. Prithvi encoder on CropTypeMapping
In addition, in the dataset config, indicate the number of time frames, e.g., `multi_temporal: 6`

An example of using SSL4EO-DINO on CropTypeMapping is as below
```
torchrun --nnodes=1 --nproc_per_node=1 pangaea/run.py \
--config-name=train \
dataset=croptypemapping \
encoder=prithvi \
encoder=ssl4eo_dino \
decoder=seg_upernet_mt_ltae \
preprocessing=seg_resize \
criterion=cross_entropy \
task=segmentation
```

To overwrite parameters, please check the Single Temporal Semantic Segmentation example
To use SatlasNet encoder, the `configs/encoder/satlasnet_mi.yaml` is required
```
torchrun --nnodes=1 --nproc_per_node=1 pangaea/run.py \
--config-name=train \
dataset=croptypemapping \
encoder=satlasnet_mi \
decoder=seg_upernet_mt_ltae decoder.multi_temporal_strategy=null \
preprocessing=seg_resize \
criterion=cross_entropy \
task=segmentation
```


To overwrite parameters, please check the Single Temporal Semantic Segmentation example.

#### Change Detection

Expand All @@ -184,12 +197,12 @@ torchrun --nnodes=1 --nproc_per_node=1 pangaea/run.py \
task=change_detection
```

To overwrite parameters, please check the Single Temporal Semantic Segmentation example
To overwrite parameters, please check the Single Temporal Semantic Segmentation example.

#### Single Temporal Regression

The regression decoder (e.g. `configs/decoder/reg_upernet.yaml`) and the regression task (e.g. `configs/task/regression.yaml`) configs should be used.
e.g. Prithvi encoder on BioMassters
E.g. Prithvi encoder on BioMassters

```
torchrun --nnodes=1 --nproc_per_node=1 pangaea/run.py \
Expand All @@ -201,13 +214,14 @@ torchrun --nnodes=1 --nproc_per_node=1 pangaea/run.py \
criterion=cross_entropy \
task=regression
```

To overwrite parameters, please check the Single Temporal Semantic Segmentation example
To use SatlasNet encoder, the `configs/encoder/satlasnet_si.yaml` is required.
To overwrite parameters, please check the Single Temporal Semantic Segmentation example.

#### Multi-Temporal Regression

The multi-temporal regression decoder (e.g. `configs/decoder/reg_upernet_mt_ltae.yaml` or `configs/decoder/reg_upernet_mt_linear.yaml`) and the regression task (e.g. `configs/task/regression.yaml`) configs should be used.
e.g. Prithvi encoder on BioMassters

Take Prithvi encoder on BioMassters as example:

```
torchrun --nnodes=1 --nproc_per_node=1 pangaea/run.py \
Expand All @@ -220,7 +234,8 @@ torchrun --nnodes=1 --nproc_per_node=1 pangaea/run.py \
task=regression
```

To overwrite parameters, please check the Single Temporal Semantic Segmentation example
To use SatlasNet encoder, please refer to the multi-temporal semantic segmentation example.
To overwrite parameters, please check the Single Temporal Semantic Segmentation example.

### 💻 End-to-end Finetuning

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
_target_: pangaea.encoders.satlasnet_encoder.SatlasNet_Encoder
encoder_weights: null
download_url: null
model_identifier: Sentinel2_SwinB_MI_MS #Sentinel2_SwinB_SI_RGB #
model_identifier: Sentinel2_SwinB_MI_MS # Multi-Image Multi-Spectral
fpn: False
input_size: 128
output_dim: 1024
Expand Down
25 changes: 25 additions & 0 deletions configs/encoder/satlasnet_si.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@

_target_: pangaea.encoders.satlasnet_encoder.SatlasNet_Encoder
encoder_weights: null
download_url: null
model_identifier: Sentinel2_SwinB_SI_MS # Single Image Multi-Spectral
fpn: False
input_size: 128
output_dim: 1024

input_bands:
optical:
# - B4
# - B3
# - B2

#MS Bands
- B2
- B3
- B4
- B5
- B6
- B7
- B8
- B11
- B12
2 changes: 2 additions & 0 deletions pangaea/encoders/ssl4eo_mae_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,7 @@ def __init__(
input_size: int,
input_bands: dict[str, list[str]],
output_layers: int | list[int],
download_url: str,
embed_dim: int = 1024,
patch_size: int = 16,
in_chans: int = 3,
Expand All @@ -217,6 +218,7 @@ def __init__(
num_heads=num_heads,
mlp_ratio=mlp_ratio,
norm_layer=norm_layer,
download_url=download_url,
)

self.model_name = "ssl4eo_mae_sar"
Expand Down