diff --git a/configs/foundation_models/croma_joint.yaml b/configs/foundation_models/croma_joint.yaml
new file mode 100644
index 00000000..0fe9e544
--- /dev/null
+++ b/configs/foundation_models/croma_joint.yaml
@@ -0,0 +1,38 @@
+encoder_name: CROMA_JOINT_Encoder
+foundation_model_name: CROMA_JOINT_large
+encoder_weights: ./pretrained_models/CROMA_large.pt
+download_url: https://huggingface.co/antofuller/CROMA/resolve/main/CROMA_large.pt
+temporal_input: False
+
+
+num_layers: 24
+embed_dim: 1024
+input_size: 120 # the paper uses 120
+
+encoder_model_args:
+  size: 'large'
+  image_resolution: 120
+
+input_bands:
+  optical:
+    - B1
+    - B2
+    - B3
+    - B4
+    - B5
+    - B6
+    - B7
+    - B8
+    - B8A
+    - B9
+    - B11
+    - B12
+  sar:
+    - VV
+    - VH
+
+output_layers:
+  - 3
+  - 5
+  - 7
+  - 11
diff --git a/configs/foundation_models/croma.yaml b/configs/foundation_models/croma_optical.yaml
similarity index 95%
rename from configs/foundation_models/croma.yaml
rename to configs/foundation_models/croma_optical.yaml
index 98a7740c..04c9e82d 100644
--- a/configs/foundation_models/croma.yaml
+++ b/configs/foundation_models/croma_optical.yaml
@@ -27,9 +27,6 @@ input_bands:
     - B9
     - B11
     - B12
-  sar:
-    - VV
-    - VH
 
 output_layers:
   - 3
diff --git a/configs/foundation_models/croma_sar.yaml b/configs/foundation_models/croma_sar.yaml
new file mode 100644
index 00000000..babca453
--- /dev/null
+++ b/configs/foundation_models/croma_sar.yaml
@@ -0,0 +1,25 @@
+encoder_name: CROMA_SAR_Encoder
+foundation_model_name: CROMA_SAR_large
+encoder_weights: ./pretrained_models/CROMA_large.pt
+download_url: https://huggingface.co/antofuller/CROMA/resolve/main/CROMA_large.pt
+temporal_input: False
+
+
+num_layers: 24
+embed_dim: 1024
+input_size: 120 # the paper uses 120
+
+encoder_model_args:
+  size: 'large'
+  image_resolution: 120
+
+input_bands:
+  sar:
+    - VV
+    - VH
+
+output_layers:
+  - 3
+  - 5
+  - 7
+  - 11
diff --git a/configs/foundation_models/ssl4eo_mae.yaml b/configs/foundation_models/ssl4eo_mae_optical.yaml
similarity index 96%
rename from configs/foundation_models/ssl4eo_mae.yaml
rename to configs/foundation_models/ssl4eo_mae_optical.yaml
index 1294e0d5..6c985fc2 100644
--- a/configs/foundation_models/ssl4eo_mae.yaml
+++ b/configs/foundation_models/ssl4eo_mae_optical.yaml
@@ -33,9 +33,7 @@ input_bands:
     - B10
    - B11
    - B12
-  sar:
-    - VV
-    - VH
+
 
 output_layers:
   - 3
diff --git a/configs/foundation_models/ssl4eo_mae_sar.yaml b/configs/foundation_models/ssl4eo_mae_sar.yaml
new file mode 100644
index 00000000..d463f0fe
--- /dev/null
+++ b/configs/foundation_models/ssl4eo_mae_sar.yaml
@@ -0,0 +1,30 @@
+encoder_name: SSL4EO_MAE_SAR_Encoder
+foundation_model_name: ssl4eo_mae_vit_small_patch16
+encoder_weights: ./pretrained_models/B2_vits16_mae_ep99.pth
+download_url: https://huggingface.co/wangyi111/SSL4EO-S12/resolve/main/B2_vits16_mae_ep99.pth
+
+temporal_input: False
+
+num_layers: 12
+embed_dim: 384
+input_size: 224
+
+encoder_model_args:
+  img_size: 224
+  in_chans: 2
+  embed_dim: 384
+  patch_size: 16
+  num_heads: 6
+  depth: 12
+  mlp_ratio: 4
+
+input_bands:
+  sar:
+    - VV
+    - VH
+
+output_layers:
+  - 3
+  - 5
+  - 7
+  - 11
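The config diffs above split each multimodal checkpoint into per-modality variants whose input_bands section lists only the modalities that encoder actually consumes. A quick, illustrative way to sanity-check one of the new files (the framework's own config loader is not part of this diff, so plain PyYAML is assumed here):

import yaml

# Load the SAR-only CROMA config and inspect which modalities/bands the encoder expects.
with open("configs/foundation_models/croma_sar.yaml") as f:
    encoder_cfg = yaml.safe_load(f)

print(encoder_cfg["encoder_name"])        # CROMA_SAR_Encoder
print(list(encoder_cfg["input_bands"]))   # ['sar'] -- the only modality this encoder takes
print(encoder_cfg["input_bands"]["sar"])  # ['VV', 'VH']

These input_bands keys are exactly what the data_preprocessor.py changes below filter on.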
diff --git a/engine/data_preprocessor.py b/engine/data_preprocessor.py
index 40e9d02f..b85fe0bd 100644
--- a/engine/data_preprocessor.py
+++ b/engine/data_preprocessor.py
@@ -72,6 +72,7 @@ def __init__(self, dataset, cfg):
         # Either use unly these, or only the input arguments.
         self.root_cfg = cfg
         self.dataset_cfg = cfg.dataset
+        self.encoder_cfg = cfg.encoder
         self.root_path = cfg.dataset.root_path
         self.classes = cfg.dataset.classes
         self.class_num = len(self.classes)
@@ -106,7 +107,7 @@ def __init__(self, dataset, cfg, local_cfg):
         )
 
         # TO DO: other modalities
-        for modality in self.dataset_cfg.bands:
+        for modality in self.encoder_cfg.input_bands:
             new_stats = self.preprocessor[modality].preprocess_band_statistics(
                 self.data_mean[modality],
                 self.data_std[modality],
@@ -123,7 +124,8 @@ def __getitem__(self, index):
         data = self.dataset[index]
 
         for k, v in data["image"].items():
-            data["image"][k] = self.preprocessor[k](v)
+            if k in self.encoder_cfg.input_bands:
+                data["image"][k] = self.preprocessor[k](v)
         data["target"] = data["target"].long()
         return data
 
@@ -138,7 +140,8 @@ def __getitem__(self, index):
         data = self.dataset[index]
 
         for k, v in data["image"].items():
-            data["image"][k] = self.preprocessor[k](v)
+            if k in self.encoder_cfg.input_bands:
+                data["image"][k] = self.preprocessor[k](v)
         data["target"] = data["target"].float()
         return data
 
@@ -293,7 +296,7 @@ def __getitem__(self, index):
         tiled_data = {"image": {}, "target": None}
         tiled_data["image"] = {}
         for k, v in data["image"].items():
-            if k not in self.ignore_modalities:
+            if k not in self.ignore_modalities and k in self.encoder_cfg.input_bands:
                 tiled_data["image"][k] = v[
                     ..., h : h + self.output_size, w : w + self.output_size
                 ].clone()
@@ -340,12 +343,12 @@ def __getitem__(self, index):
         data = self.dataset[index]
         if random.random() < self.ud_probability:
             for k, v in data["image"].items():
-                if k not in self.ignore_modalities:
+                if k not in self.ignore_modalities and k in self.encoder_cfg.input_bands:
                     data["image"][k] = torch.fliplr(v)
             data["target"] = torch.fliplr(data["target"])
         if random.random() < self.lr_probability:
             for k, v in data["image"].items():
-                if k not in self.ignore_modalities:
+                if k not in self.ignore_modalities and k in self.encoder_cfg.input_bands:
                     data["image"][k] = torch.flipud(v)
             data["target"] = torch.flipud(data["target"])
         return data
@@ -361,7 +364,7 @@ def __init__(self, dataset, cfg, local_cfg):
     def __getitem__(self, index):
         data = self.dataset[index]
         if random.random() < self.probability:
             for k, v in data["image"].items():
-                if k not in self.ignore_modalities:
+                if k not in self.ignore_modalities and k in self.encoder_cfg.input_bands:
                     data["image"][k] = torch.pow(v, random.uniform(*self.gamma_range))
         return data
@@ -374,7 +377,7 @@ def __init__(self, dataset, cfg, local_cfg):
         self.data_mean_tensors = {}
         self.data_std_tensors = {}
         # Bands is a dict of {modality:[b1, b2, ...], ...} so it's keys are the modalaities in use
-        for modality in self.dataset_cfg.bands:
+        for modality in self.encoder_cfg.input_bands:
             self.data_mean_tensors[modality] = torch.tensor(
                 self.data_mean[modality]
             ).reshape((-1, 1, 1, 1))
@@ -384,7 +387,7 @@ def __init__(self, dataset, cfg, local_cfg):
 
     def __getitem__(self, index):
         data = self.dataset[index]
-        for modality in data["image"]:
+        for modality in self.encoder_cfg.input_bands:
             if modality not in self.ignore_modalities:
                 data["image"][modality] = (
                     data["image"][modality] - self.data_mean_tensors[modality]
@@ -401,7 +404,7 @@ def __init__(self, dataset, cfg, local_cfg):
         self.data_max_tensors = {}
         self.min = local_cfg.min
         self.max = local_cfg.max
-        for modality in self.dataset_cfg.bands:
+        for modality in self.encoder_cfg.input_bands:
             self.data_min_tensors[modality] = torch.tensor(
                 self.data_min[modality]
             ).reshape((-1, 1, 1, 1))
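For context on the statistics hunks above: the mean/std (and min/max) tensors are now keyed by the modalities the encoder declares in its config rather than by the dataset's band dictionary, and the normalization loop iterates those encoder modalities. A minimal, self-contained sketch of that pattern, with illustrative names rather than the repo's actual classes:

import torch

class PerModalityNormalizer:
    def __init__(self, encoder_input_bands, data_mean, data_std):
        # Build (C, 1, 1, 1) statistics tensors only for modalities the encoder consumes.
        self.mean = {m: torch.tensor(data_mean[m]).reshape(-1, 1, 1, 1)
                     for m in encoder_input_bands}
        self.std = {m: torch.tensor(data_std[m]).reshape(-1, 1, 1, 1)
                    for m in encoder_input_bands}

    def __call__(self, image):
        # Iterate the encoder's modalities, not the sample's keys, and guard against
        # modalities the sample does not carry.
        for m in self.mean:
            if m in image:
                image[m] = (image[m] - self.mean[m]) / self.std[m]
        return image

normalizer = PerModalityNormalizer(
    encoder_input_bands={"sar": ["VV", "VH"]},   # mirrors croma_sar.yaml
    data_mean={"sar": [0.1, 0.2]},
    data_std={"sar": [0.5, 0.4]},
)
sample = {"sar": torch.randn(2, 1, 120, 120)}    # (C, T, H, W) layout assumed
sample = normalizer(sample)

Note that the hunks above iterate self.encoder_cfg.input_bands directly, which assumes every declared modality is present in each sample; the "if m in image" guard shown here is only a defensive variant.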
@@ -411,7 +414,7 @@ def __init__(self, dataset, cfg, local_cfg):
 
     def __getitem__(self, index):
         data = self.dataset[index]
-        for modality in data["image"]:
+        for modality in self.encoder_cfg.input_bands:
             if modality not in self.ignore_modalities:
                 data["image"][modality] = (
                     (data["image"][modality] - self.data_min_tensors[modality])
@@ -460,20 +463,22 @@ def __getitem__(self, index):
         data = self.dataset[index]
 
         for k, v in data["image"].items():
-            brightness = random.uniform(-self.brightness, self.brightness)
-            if random.random() < self.br_probability:
-                if k not in self.ignore_modalities:
-                    data["image"][k] = self.adjust_brightness(
-                        data["image"][k], brightness, self.clip
+            if k not in self.ignore_modalities and k in self.encoder_cfg.input_bands:
+                brightness = random.uniform(-self.brightness, self.brightness)
+                if random.random() < self.br_probability:
+                    if k not in self.ignore_modalities:
+                        data["image"][k] = self.adjust_brightness(
+                            data["image"][k], brightness, self.clip
                 )
 
         for k, v in data["image"].items():
-            if random.random() < self.ct_probability:
-                contrast = random.uniform(1 - self.contrast, 1 + self.contrast)
-                if k not in self.ignore_modalities:
-                    data["image"][k] = self.adjust_contrast(
-                        data["image"][k], contrast, self.clip
-                    )
+            if k not in self.ignore_modalities and k in self.encoder_cfg.input_bands:
+                if random.random() < self.ct_probability:
+                    contrast = random.uniform(1 - self.contrast, 1 + self.contrast)
+                    if k not in self.ignore_modalities:
+                        data["image"][k] = self.adjust_contrast(
+                            data["image"][k], contrast, self.clip
+                        )
 
         return data
 
@@ -487,7 +492,7 @@ def __init__(self, dataset, cfg, local_cfg):
     def __getitem__(self, index):
         data = self.dataset[index]
         for k, v in data["image"].items():
-            if k not in self.ignore_modalities:
+            if k not in self.ignore_modalities and k in self.encoder_cfg.input_bands:
                 data["image"][k] = T.Resize(self.size)(v)
 
         if data["target"].ndim == 2:
@@ -531,7 +536,7 @@ def __getitem__(self, index):
             output_size=(self.size, self.size),
         )
         for k, v in data["image"].items():
-            if k not in self.ignore_modalities:
+            if k not in self.ignore_modalities and k in self.encoder_cfg.input_bands:
                 data["image"][k] = T.functional.crop(v, i, j, h, w)
         data["target"] = T.functional.crop(data["target"], i, j, h, w)
 
@@ -584,7 +589,7 @@ def __getitem__(self, index):
         i, j, h, w = crop_candidates[crop_idx]
 
         for k, v in data["image"].items():
-            if k not in self.ignore_modalities:
+            if k not in self.ignore_modalities and k in self.encoder_cfg.input_bands:
                 data["image"][k] = T.functional.crop(v, i, j, h, w)
         data["target"] = T.functional.crop(data["target"], i, j, h, w)
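Taken together, the data_preprocessor.py hunks apply one gating pattern: every per-modality transform (preprocessing, tiling, flips, gamma, brightness/contrast, resize, crops) only touches modalities listed in the encoder's input_bands, so a multimodal sample can be fed to a single-modality encoder without the unused modalities breaking the pipeline. A stand-alone sketch of the gate, with illustrative names:

import torch

encoder_input_bands = {"sar": ["VV", "VH"]}   # e.g. parsed from croma_sar.yaml
ignore_modalities = set()

def apply_to_selected(sample, transform):
    # Apply a transform only to modalities the encoder consumes; leave the rest untouched.
    for k, v in sample["image"].items():
        if k not in ignore_modalities and k in encoder_input_bands:
            sample["image"][k] = transform(v)
    return sample

# A sample carrying both SAR and optical; only SAR is flipped for a SAR-only encoder.
sample = {"image": {"sar": torch.randn(2, 120, 120),
                    "optical": torch.randn(12, 120, 120)}}
sample = apply_to_selected(sample, lambda v: torch.flip(v, dims=[-1]))  # horizontal flip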