Process bands based on what encoder needs (#35)
yurujaja authored Sep 12, 2024
1 parent 56a107c commit 35e600a
Showing 6 changed files with 124 additions and 31 deletions.
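The common thread across these files: preprocessing is now driven by the encoder's input_bands (taken from the encoder config) rather than the dataset's band list, so modalities an encoder does not consume are simply skipped. A minimal sketch of that idea, using illustrative names only (not the repository's actual API):

# Sketch: keep only the modalities the encoder declares in its input_bands mapping,
# e.g. {"optical": [...], "sar": ["VV", "VH"]} as in the YAML configs below.
def filter_to_encoder_bands(sample_image: dict, encoder_input_bands: dict) -> dict:
    return {k: v for k, v in sample_image.items() if k in encoder_input_bands}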
38 changes: 38 additions & 0 deletions configs/foundation_models/croma_joint.yaml
@@ -0,0 +1,38 @@
encoder_name: CROMA_JOINT_Encoder
foundation_model_name: CROMA_JOINT_large
encoder_weights: ./pretrained_models/CROMA_large.pt
download_url: https://huggingface.co/antofuller/CROMA/resolve/main/CROMA_large.pt
temporal_input: False


num_layers: 24
embed_dim: 1024
input_size: 120 # the paper uses 120

encoder_model_args:
  size: 'large'
  image_resolution: 120

input_bands:
  optical:
    - B1
    - B2
    - B3
    - B4
    - B5
    - B6
    - B7
    - B8
    - B8A
    - B9
    - B11
    - B12
  sar:
    - VV
    - VH

output_layers:
  - 3
  - 5
  - 7
  - 11
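A quick way to sanity-check a config such as croma_joint.yaml is to load it and inspect the declared bands; this sketch assumes plain PyYAML rather than the repository's own config loader:

import yaml

# Path as added in this commit; adjust to your checkout.
with open("configs/foundation_models/croma_joint.yaml") as f:
    cfg = yaml.safe_load(f)

print(cfg["encoder_name"])                 # CROMA_JOINT_Encoder
print(list(cfg["input_bands"]))            # ['optical', 'sar']
print(len(cfg["input_bands"]["optical"]))  # 12 optical bands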
@@ -27,9 +27,6 @@ input_bands:
    - B9
    - B11
    - B12
-  sar:
-    - VV
-    - VH

output_layers:
  - 3
25 changes: 25 additions & 0 deletions configs/foundation_models/croma_sar.yaml
@@ -0,0 +1,25 @@
encoder_name: CROMA_SAR_Encoder
foundation_model_name: CROMA_SAR_large
encoder_weights: ./pretrained_models/CROMA_large.pt
download_url: https://huggingface.co/antofuller/CROMA/resolve/main/CROMA_large.pt
temporal_input: False


num_layers: 24
embed_dim: 1024
input_size: 120 # the paper uses 120

encoder_model_args:
  size: 'large'
  image_resolution: 120

input_bands:
  sar:
    - VV
    - VH

output_layers:
  - 3
  - 5
  - 7
  - 11
@@ -33,9 +33,7 @@ input_bands:
    - B10
    - B11
    - B12
-  sar:
-    - VV
-    - VH


output_layers:
  - 3
30 changes: 30 additions & 0 deletions configs/foundation_models/ssl4eo_mae_sar.yaml
@@ -0,0 +1,30 @@
encoder_name: SSL4EO_MAE_SAR_Encoder
foundation_model_name: ssl4eo_mae_vit_small_patch16
encoder_weights: ./pretrained_models/B2_vits16_mae_ep99.pth
download_url: https://huggingface.co/wangyi111/SSL4EO-S12/resolve/main/B2_vits16_mae_ep99.pth

temporal_input: False

num_layers: 12
embed_dim: 384
input_size: 224

encoder_model_args:
  img_size: 224
  in_chans: 2
  embed_dim: 384
  patch_size: 16
  num_heads: 6
  depth: 12
  mlp_ratio: 4

input_bands:
  sar:
    - VV
    - VH

output_layers:
  - 3
  - 5
  - 7
  - 11
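For the SAR-only SSL4EO-MAE config, the SAR band list is expected to line up with the encoder's in_chans; a small check along the same lines (again assuming plain PyYAML):

import yaml

with open("configs/foundation_models/ssl4eo_mae_sar.yaml") as f:
    cfg = yaml.safe_load(f)

sar_bands = cfg["input_bands"]["sar"]                           # ['VV', 'VH']
assert len(sar_bands) == cfg["encoder_model_args"]["in_chans"]  # in_chans: 2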
55 changes: 30 additions & 25 deletions engine/data_preprocessor.py
@@ -72,6 +72,7 @@ def __init__(self, dataset, cfg):
        # Either use only these, or only the input arguments.
self.root_cfg = cfg
self.dataset_cfg = cfg.dataset
+        self.encoder_cfg = cfg.encoder
self.root_path = cfg.dataset.root_path
self.classes = cfg.dataset.classes
self.class_num = len(self.classes)
@@ -106,7 +107,7 @@ def __init__(self, dataset, cfg, local_cfg):
)
# TO DO: other modalities

-        for modality in self.dataset_cfg.bands:
+        for modality in self.encoder_cfg.input_bands:
new_stats = self.preprocessor[modality].preprocess_band_statistics(
self.data_mean[modality],
self.data_std[modality],
@@ -123,7 +124,8 @@ def __getitem__(self, index):
data = self.dataset[index]

for k, v in data["image"].items():
data["image"][k] = self.preprocessor[k](v)
if k in self.encoder_cfg.input_bands:
data["image"][k] = self.preprocessor[k](v)

data["target"] = data["target"].long()
return data
@@ -138,7 +140,8 @@ def __getitem__(self, index):
data = self.dataset[index]

for k, v in data["image"].items():
data["image"][k] = self.preprocessor[k](v)
if k in self.encoder_cfg.input_bands:
data["image"][k] = self.preprocessor[k](v)

data["target"] = data["target"].float()
return data
@@ -293,7 +296,7 @@ def __getitem__(self, index):
tiled_data = {"image": {}, "target": None}
tiled_data["image"] = {}
for k, v in data["image"].items():
-            if k not in self.ignore_modalities:
+            if k not in self.ignore_modalities and k in self.encoder_cfg.input_bands:
tiled_data["image"][k] = v[
..., h : h + self.output_size, w : w + self.output_size
].clone()
@@ -340,12 +343,12 @@ def __getitem__(self, index):
data = self.dataset[index]
if random.random() < self.ud_probability:
for k, v in data["image"].items():
-                if k not in self.ignore_modalities:
+                if k not in self.ignore_modalities and k in self.encoder_cfg.input_bands:
data["image"][k] = torch.fliplr(v)
data["target"] = torch.fliplr(data["target"])
if random.random() < self.lr_probability:
for k, v in data["image"].items():
-                if k not in self.ignore_modalities:
+                if k not in self.ignore_modalities and k in self.encoder_cfg.input_bands:
data["image"][k] = torch.flipud(v)
data["target"] = torch.flipud(data["target"])
return data
@@ -361,7 +364,7 @@ def __init__(self, dataset, cfg, local_cfg):
def __getitem__(self, index):
data = self.dataset[index]
if random.random() < self.probability:
for k, v in data["image"].items():
for k, v in data["image"].items() and k in self.encoder_cfg.input_bands:
if k not in self.ignore_modalities:
data["image"][k] = torch.pow(v, random.uniform(*self.gamma_range))
return data
@@ -374,7 +377,7 @@ def __init__(self, dataset, cfg, local_cfg):
self.data_mean_tensors = {}
self.data_std_tensors = {}
        # Bands is a dict of {modality: [b1, b2, ...], ...}, so its keys are the modalities in use
-        for modality in self.dataset_cfg.bands:
+        for modality in self.encoder_cfg.input_bands:
self.data_mean_tensors[modality] = torch.tensor(
self.data_mean[modality]
).reshape((-1, 1, 1, 1))
@@ -384,7 +387,7 @@ def __init__(self, dataset, cfg, local_cfg):

def __getitem__(self, index):
data = self.dataset[index]
for modality in data["image"]:
for modality in self.encoder_cfg.input_bands:
if modality not in self.ignore_modalities:
data["image"][modality] = (
data["image"][modality] - self.data_mean_tensors[modality]
@@ -401,7 +404,7 @@ def __init__(self, dataset, cfg, local_cfg):
self.data_max_tensors = {}
self.min = local_cfg.min
self.max = local_cfg.max
-        for modality in self.dataset_cfg.bands:
+        for modality in self.encoder_cfg.input_bands:
self.data_min_tensors[modality] = torch.tensor(
self.data_min[modality]
).reshape((-1, 1, 1, 1))
@@ -411,7 +414,7 @@ def __init__(self, dataset, cfg, local_cfg):

def __getitem__(self, index):
data = self.dataset[index]
for modality in data["image"]:
for modality in self.encoder_cfg.input_bands:
if modality not in self.ignore_modalities:
data["image"][modality] = (
(data["image"][modality] - self.data_min_tensors[modality])
@@ -460,20 +463,22 @@ def __getitem__(self, index):
data = self.dataset[index]

for k, v in data["image"].items():
-            brightness = random.uniform(-self.brightness, self.brightness)
-            if random.random() < self.br_probability:
-                if k not in self.ignore_modalities:
-                    data["image"][k] = self.adjust_brightness(
-                        data["image"][k], brightness, self.clip
+            if k not in self.ignore_modalities and k in self.encoder_cfg.input_bands:
+                brightness = random.uniform(-self.brightness, self.brightness)
+                if random.random() < self.br_probability:
+                    if k not in self.ignore_modalities:
+                        data["image"][k] = self.adjust_brightness(
+                            data["image"][k], brightness, self.clip
+                        )

for k, v in data["image"].items():
-            if random.random() < self.ct_probability:
-                contrast = random.uniform(1 - self.contrast, 1 + self.contrast)
-                if k not in self.ignore_modalities:
-                    data["image"][k] = self.adjust_contrast(
-                        data["image"][k], contrast, self.clip
-                    )
+            if k not in self.ignore_modalities and k in self.encoder_cfg.input_bands:
+                if random.random() < self.ct_probability:
+                    contrast = random.uniform(1 - self.contrast, 1 + self.contrast)
+                    if k not in self.ignore_modalities:
+                        data["image"][k] = self.adjust_contrast(
+                            data["image"][k], contrast, self.clip
+                        )

return data

@@ -487,7 +492,7 @@ def __init__(self, dataset, cfg, local_cfg):
def __getitem__(self, index):
data = self.dataset[index]
for k, v in data["image"].items():
-            if k not in self.ignore_modalities:
+            if k not in self.ignore_modalities and k in self.encoder_cfg.input_bands:
data["image"][k] = T.Resize(self.size)(v)

if data["target"].ndim == 2:
@@ -531,7 +536,7 @@ def __getitem__(self, index):
output_size=(self.size, self.size),
)
for k, v in data["image"].items():
-            if k not in self.ignore_modalities:
+            if k not in self.ignore_modalities and k in self.encoder_cfg.input_bands:
data["image"][k] = T.functional.crop(v, i, j, h, w)
data["target"] = T.functional.crop(data["target"], i, j, h, w)

@@ -584,7 +589,7 @@ def __getitem__(self, index):
i, j, h, w = crop_candidates[crop_idx]

for k, v in data["image"].items():
-            if k not in self.ignore_modalities:
+            if k not in self.ignore_modalities and k in self.encoder_cfg.input_bands:
data["image"][k] = T.functional.crop(v, i, j, h, w)
data["target"] = T.functional.crop(data["target"], i, j, h, w)

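Taken together, the normalization and augmentation preprocessors above now iterate over the encoder's input_bands, leaving other modalities in the sample untouched. A self-contained sketch of that behavior (illustrative statistics and shapes, not the repository's real values):

import torch

# Encoder consumes SAR only; the optical modality in the sample is left as-is.
encoder_input_bands = {"sar": ["VV", "VH"]}
data_mean = {"sar": [0.1, 0.2], "optical": [0.3] * 12}
data_std = {"sar": [0.05, 0.07], "optical": [0.1] * 12}

sample = {
    "sar": torch.rand(2, 1, 8, 8),       # (bands, time, H, W)
    "optical": torch.rand(12, 1, 8, 8),  # not in input_bands, so skipped
}

for modality in encoder_input_bands:
    mean = torch.tensor(data_mean[modality]).reshape((-1, 1, 1, 1))
    std = torch.tensor(data_std[modality]).reshape((-1, 1, 1, 1))
    sample[modality] = (sample[modality] - mean) / std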
