Fix #23 Process bands based on what encoder needs #35

Merged
merged 1 commit into from Sep 12, 2024
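In short, the data preprocessing wrappers now key on the encoder config's input_bands rather than the dataset's band list, so modalities the chosen encoder does not consume are skipped during normalization and augmentation. Below is a minimal sketch of that filtering pattern; the function and variable names are illustrative only, and the actual change lives in engine/data_preprocessor.py further down.

import torch

def preprocess_needed_modalities(image, input_bands, preprocessor):
    """Apply per-modality preprocessing only to modalities the encoder consumes.

    image: dict of modality name ("optical", "sar", ...) -> tensor
    input_bands: the encoder config's input_bands mapping (modality -> band list)
    preprocessor: dict of per-modality callables
    """
    return {
        modality: preprocessor[modality](tensor)
        for modality, tensor in image.items()
        if modality in input_bands  # drop anything the encoder does not need
    }

# Hypothetical usage: a SAR-only encoder drops the optical tensor entirely.
sample = {"optical": torch.zeros(12, 120, 120), "sar": torch.zeros(2, 120, 120)}
sar_only = {"sar": ["VV", "VH"]}
out = preprocess_needed_modalities(sample, sar_only, {"sar": lambda t: t})
assert list(out) == ["sar"]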
38 changes: 38 additions & 0 deletions configs/foundation_models/croma_joint.yaml
@@ -0,0 +1,38 @@
encoder_name: CROMA_JOINT_Encoder
foundation_model_name: CROMA_JOINT_large
encoder_weights: ./pretrained_models/CROMA_large.pt
download_url: https://huggingface.co/antofuller/CROMA/resolve/main/CROMA_large.pt
temporal_input: False


num_layers: 24
embed_dim: 1024
input_size: 120 # the paper uses 120

encoder_model_args:
size: 'large'
image_resolution: 120

input_bands:
optical:
- B1
- B2
- B3
- B4
- B5
- B6
- B7
- B8
- B8A
- B9
- B11
- B12
sar:
- VV
- VH

output_layers:
- 3
- 5
- 7
- 11
@@ -27,9 +27,6 @@ input_bands:
- B9
- B11
- B12
sar:
- VV
- VH

output_layers:
- 3
25 changes: 25 additions & 0 deletions configs/foundation_models/croma_sar.yaml
@@ -0,0 +1,25 @@
encoder_name: CROMA_SAR_Encoder
foundation_model_name: CROMA_SAR_large
encoder_weights: ./pretrained_models/CROMA_large.pt
download_url: https://huggingface.co/antofuller/CROMA/resolve/main/CROMA_large.pt
temporal_input: False


num_layers: 24
embed_dim: 1024
input_size: 120 # the paper uses 120

encoder_model_args:
size: 'large'
image_resolution: 120

input_bands:
sar:
- VV
- VH

output_layers:
- 3
- 5
- 7
- 11
@@ -33,9 +33,7 @@ input_bands:
- B10
- B11
- B12
sar:
- VV
- VH


output_layers:
- 3
30 changes: 30 additions & 0 deletions configs/foundation_models/ssl4eo_mae_sar.yaml
@@ -0,0 +1,30 @@
encoder_name: SSL4EO_MAE_SAR_Encoder
foundation_model_name: ssl4eo_mae_vit_small_patch16
encoder_weights: ./pretrained_models/B2_vits16_mae_ep99.pth
download_url: https://huggingface.co/wangyi111/SSL4EO-S12/resolve/main/B2_vits16_mae_ep99.pth

temporal_input: False

num_layers: 12
embed_dim: 384
input_size: 224

encoder_model_args:
img_size: 224
in_chans: 2
embed_dim: 384
patch_size: 16
num_heads: 6
depth: 12
mlp_ratio: 4

input_bands:
sar:
- VV
- VH

output_layers:
- 3
- 5
- 7
- 11
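
For reference, a short sketch (assuming PyYAML is installed and the repository paths above) of how the new SAR-only config exposes its input_bands, which is the mapping the preprocessor changes below filter on.

import yaml  # PyYAML; assumed available in the environment

# Inspect which modalities the SAR-only encoder declares.
with open("configs/foundation_models/ssl4eo_mae_sar.yaml") as f:
    cfg = yaml.safe_load(f)

print(list(cfg["input_bands"]))   # ['sar']
print(cfg["input_bands"]["sar"])  # ['VV', 'VH']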
55 changes: 30 additions & 25 deletions engine/data_preprocessor.py
@@ -72,6 +72,7 @@ def __init__(self, dataset, cfg):
        # Either use only these, or only the input arguments.
self.root_cfg = cfg
self.dataset_cfg = cfg.dataset
self.encoder_cfg = cfg.encoder
self.root_path = cfg.dataset.root_path
self.classes = cfg.dataset.classes
self.class_num = len(self.classes)
@@ -106,7 +107,7 @@ def __init__(self, dataset, cfg, local_cfg):
)
# TO DO: other modalities

for modality in self.dataset_cfg.bands:
for modality in self.encoder_cfg.input_bands:
new_stats = self.preprocessor[modality].preprocess_band_statistics(
self.data_mean[modality],
self.data_std[modality],
@@ -123,7 +124,8 @@ def __getitem__(self, index):
data = self.dataset[index]

for k, v in data["image"].items():
data["image"][k] = self.preprocessor[k](v)
if k in self.encoder_cfg.input_bands:
data["image"][k] = self.preprocessor[k](v)

data["target"] = data["target"].long()
return data
@@ -138,7 +140,8 @@ def __getitem__(self, index):
data = self.dataset[index]

for k, v in data["image"].items():
data["image"][k] = self.preprocessor[k](v)
if k in self.encoder_cfg.input_bands:
data["image"][k] = self.preprocessor[k](v)

data["target"] = data["target"].float()
return data
@@ -293,7 +296,7 @@ def __getitem__(self, index):
tiled_data = {"image": {}, "target": None}
tiled_data["image"] = {}
for k, v in data["image"].items():
if k not in self.ignore_modalities:
if k not in self.ignore_modalities and k in self.encoder_cfg.input_bands:
tiled_data["image"][k] = v[
..., h : h + self.output_size, w : w + self.output_size
].clone()
@@ -340,12 +343,12 @@ def __getitem__(self, index):
data = self.dataset[index]
if random.random() < self.ud_probability:
for k, v in data["image"].items():
if k not in self.ignore_modalities:
if k not in self.ignore_modalities and k in self.encoder_cfg.input_bands:
data["image"][k] = torch.fliplr(v)
data["target"] = torch.fliplr(data["target"])
if random.random() < self.lr_probability:
for k, v in data["image"].items():
if k not in self.ignore_modalities:
if k not in self.ignore_modalities and k in self.encoder_cfg.input_bands:
data["image"][k] = torch.flipud(v)
data["target"] = torch.flipud(data["target"])
return data
@@ -361,7 +364,7 @@ def __init__(self, dataset, cfg, local_cfg):
def __getitem__(self, index):
data = self.dataset[index]
if random.random() < self.probability:
            for k, v in data["image"].items():
                if k not in self.ignore_modalities:
                if k not in self.ignore_modalities and k in self.encoder_cfg.input_bands:
data["image"][k] = torch.pow(v, random.uniform(*self.gamma_range))
return data
@@ -374,7 +377,7 @@ def __init__(self, dataset, cfg, local_cfg):
self.data_mean_tensors = {}
self.data_std_tensors = {}
        # Bands is a dict of {modality: [b1, b2, ...], ...} so its keys are the modalities in use
for modality in self.dataset_cfg.bands:
for modality in self.encoder_cfg.input_bands:
self.data_mean_tensors[modality] = torch.tensor(
self.data_mean[modality]
).reshape((-1, 1, 1, 1))
@@ -384,7 +387,7 @@ def __init__(self, dataset, cfg, local_cfg):

def __getitem__(self, index):
data = self.dataset[index]
for modality in data["image"]:
for modality in self.encoder_cfg.input_bands:
if modality not in self.ignore_modalities:
data["image"][modality] = (
data["image"][modality] - self.data_mean_tensors[modality]
@@ -401,7 +404,7 @@ def __init__(self, dataset, cfg, local_cfg):
self.data_max_tensors = {}
self.min = local_cfg.min
self.max = local_cfg.max
for modality in self.dataset_cfg.bands:
for modality in self.encoder_cfg.input_bands:
self.data_min_tensors[modality] = torch.tensor(
self.data_min[modality]
).reshape((-1, 1, 1, 1))
@@ -411,7 +414,7 @@ def __init__(self, dataset, cfg, local_cfg):

def __getitem__(self, index):
data = self.dataset[index]
for modality in data["image"]:
for modality in self.encoder_cfg.input_bands:
if modality not in self.ignore_modalities:
data["image"][modality] = (
(data["image"][modality] - self.data_min_tensors[modality])
@@ -460,20 +463,22 @@ def __getitem__(self, index):
data = self.dataset[index]

for k, v in data["image"].items():
brightness = random.uniform(-self.brightness, self.brightness)
if random.random() < self.br_probability:
if k not in self.ignore_modalities:
data["image"][k] = self.adjust_brightness(
data["image"][k], brightness, self.clip
if k not in self.ignore_modalities and k in self.encoder_cfg.input_bands:
brightness = random.uniform(-self.brightness, self.brightness)
if random.random() < self.br_probability:
if k not in self.ignore_modalities:
data["image"][k] = self.adjust_brightness(
data["image"][k], brightness, self.clip
)

for k, v in data["image"].items():
if random.random() < self.ct_probability:
contrast = random.uniform(1 - self.contrast, 1 + self.contrast)
if k not in self.ignore_modalities:
data["image"][k] = self.adjust_contrast(
data["image"][k], contrast, self.clip
)
if k not in self.ignore_modalities and k in self.encoder_cfg.input_bands:
if random.random() < self.ct_probability:
contrast = random.uniform(1 - self.contrast, 1 + self.contrast)
if k not in self.ignore_modalities:
data["image"][k] = self.adjust_contrast(
data["image"][k], contrast, self.clip
)

return data

@@ -487,7 +492,7 @@ def __init__(self, dataset, cfg, local_cfg):
def __getitem__(self, index):
data = self.dataset[index]
for k, v in data["image"].items():
if k not in self.ignore_modalities:
if k not in self.ignore_modalities and k in self.encoder_cfg.input_bands:
data["image"][k] = T.Resize(self.size)(v)

if data["target"].ndim == 2:
@@ -531,7 +536,7 @@ def __getitem__(self, index):
output_size=(self.size, self.size),
)
for k, v in data["image"].items():
if k not in self.ignore_modalities:
if k not in self.ignore_modalities and k in self.encoder_cfg.input_bands:
data["image"][k] = T.functional.crop(v, i, j, h, w)
data["target"] = T.functional.crop(data["target"], i, j, h, w)

@@ -584,7 +589,7 @@ def __getitem__(self, index):
i, j, h, w = crop_candidates[crop_idx]

for k, v in data["image"].items():
if k not in self.ignore_modalities:
if k not in self.ignore_modalities and k in self.encoder_cfg.input_bands:
data["image"][k] = T.functional.crop(v, i, j, h, w)
data["target"] = T.functional.crop(data["target"], i, j, h, w)
