diff --git a/configs/datasets/dynamicen.yaml b/configs/datasets/dynamicen.yaml new file mode 100644 index 00000000..190c10c5 --- /dev/null +++ b/configs/datasets/dynamicen.yaml @@ -0,0 +1,52 @@ +dataset_name: DynamicEarthNet +root_path: ./data/DynamicEarthNet +download_url: None +auto_download: False + +img_size: 1024 +multi_temporal: False +multi_modal: False + +# classes +ignore_index: -1 +num_classes: 6 +classes: + - impervious surface + - agriculture + - forest & other vegetation + - wetlands + - soil + - water +distribution: + - 0.071 + - 0.103 + - 0.449 + - 0.07 + - 0.28 + - 0.08 + +# data stats +bands: + optical: + - B2 + - B3 + - B4 + - B5 + +data_mean: + optical: + - 1042.59240722656 + - 915.618408203125 + - 671.260559082031 + - 2605.20922851562 +data_std: + optical: + - 957.958435058593 + - 715.548767089843 + - 596.943908691406 + - 1059.90319824218 + +data_min: + optical: [0.0000, 0.0000, 0.0000, 0.0000] +data_max: + optical: [0.0000, 0.0000, 0.0000, 0.0000] \ No newline at end of file diff --git a/configs/datasets/spacenet7.yaml b/configs/datasets/spacenet7.yaml index 4269acb4..4cbd0b75 100644 --- a/configs/datasets/spacenet7.yaml +++ b/configs/datasets/spacenet7.yaml @@ -3,7 +3,12 @@ root_path: ./data/spacenet7 download_url: https://drive.google.com/uc?id=1BADSEjxYKFZZlM-tEkRUfHvHi5XdaVV9 auto_download: True -img_size: 1024 +img_size: 256 # the image size is used to tile the SpaceNet 7 images (1024, 1024) +domain_shift: False +# parameters for within-scene splits (no domain shift) +i_split: 768 +j_split: 512 + multi_temporal: False multi_modal: False diff --git a/configs/datasets/spacenet7_domainshift.yaml b/configs/datasets/spacenet7_domainshift.yaml new file mode 100644 index 00000000..209b755c --- /dev/null +++ b/configs/datasets/spacenet7_domainshift.yaml @@ -0,0 +1,44 @@ +dataset_name: SN7MAPPING +root_path: ./data/spacenet7 +download_url: https://drive.google.com/uc?id=1BADSEjxYKFZZlM-tEkRUfHvHi5XdaVV9 +auto_download: True + +img_size: 256 # the 
image size is used to tile the SpaceNet 7 images (1024, 1024) +domain_shift: True +# parameters for within-scene splits (no domain shift) +i_split: 768 +j_split: 512 + +multi_temporal: False +multi_modal: False + +# classes +ignore_index: -1 +num_classes: 2 +classes: + - Background + - Building +distribution: + - 0.92530769 + - 0.07469231 + +# data stats +bands: + optical: + - B4 # Band 1 (Red) + - B3 # Band 2 (Green) + - B2 # Band 3 (Blue) +data_mean: + optical: + - 121.826 + - 106.52838 + - 78.372116 +data_std: + optical: + - 56.717068 + - 44.517075 + - 40.451515 +data_min: + optical: [0.0, 0.0, 0.0] +data_max: + optical: [255.0, 255.0, 255.0] \ No newline at end of file diff --git a/configs/datasets/spacenet7cd.yaml b/configs/datasets/spacenet7cd.yaml index 6a593422..5b873338 100644 --- a/configs/datasets/spacenet7cd.yaml +++ b/configs/datasets/spacenet7cd.yaml @@ -3,10 +3,17 @@ root_path: ./data/spacenet7 download_url: https://drive.google.com/uc?id=1BADSEjxYKFZZlM-tEkRUfHvHi5XdaVV9 auto_download: True -img_size: 1024 +img_size: 256 # the image size is used to tile the SpaceNet 7 images (1024, 1024) +domain_shift: False +# parameters for within-scene splits (no domain shift) +i_split: 768 +j_split: 512 + multi_temporal: 2 multi_modal: False +dataset_multiplier: 1 # multiplies sample in dataset during training + # classes ignore_index: -1 num_classes: 2 diff --git a/configs/foundation_models/croma_joint.yaml b/configs/foundation_models/croma_joint.yaml new file mode 100644 index 00000000..0fe9e544 --- /dev/null +++ b/configs/foundation_models/croma_joint.yaml @@ -0,0 +1,38 @@ +encoder_name: CROMA_JOINT_Encoder +foundation_model_name: CROMA_JOINT_large +encoder_weights: ./pretrained_models/CROMA_large.pt +download_url: https://huggingface.co/antofuller/CROMA/resolve/main/CROMA_large.pt +temporal_input: False + + +num_layers: 24 +embed_dim: 1024 +input_size: 120 # the paper uses 120 + +encoder_model_args: + size: 'large' + image_resolution: 120 + +input_bands: + 
optical: + - B1 + - B2 + - B3 + - B4 + - B5 + - B6 + - B7 + - B8 + - B8A + - B9 + - B11 + - B12 + sar: + - VV + - VH + +output_layers: + - 3 + - 5 + - 7 + - 11 diff --git a/configs/foundation_models/croma.yaml b/configs/foundation_models/croma_optical.yaml similarity index 95% rename from configs/foundation_models/croma.yaml rename to configs/foundation_models/croma_optical.yaml index 98a7740c..04c9e82d 100644 --- a/configs/foundation_models/croma.yaml +++ b/configs/foundation_models/croma_optical.yaml @@ -27,9 +27,6 @@ input_bands: - B9 - B11 - B12 - sar: - - VV - - VH output_layers: - 3 diff --git a/configs/foundation_models/croma_sar.yaml b/configs/foundation_models/croma_sar.yaml new file mode 100644 index 00000000..babca453 --- /dev/null +++ b/configs/foundation_models/croma_sar.yaml @@ -0,0 +1,25 @@ +encoder_name: CROMA_SAR_Encoder +foundation_model_name: CROMA_SAR_large +encoder_weights: ./pretrained_models/CROMA_large.pt +download_url: https://huggingface.co/antofuller/CROMA/resolve/main/CROMA_large.pt +temporal_input: False + + +num_layers: 24 +embed_dim: 1024 +input_size: 120 # the paper uses 120 + +encoder_model_args: + size: 'large' + image_resolution: 120 + +input_bands: + sar: + - VV + - VH + +output_layers: + - 3 + - 5 + - 7 + - 11 diff --git a/configs/foundation_models/ssl4eo_mae.yaml b/configs/foundation_models/ssl4eo_mae_optical.yaml similarity index 96% rename from configs/foundation_models/ssl4eo_mae.yaml rename to configs/foundation_models/ssl4eo_mae_optical.yaml index 1294e0d5..6c985fc2 100644 --- a/configs/foundation_models/ssl4eo_mae.yaml +++ b/configs/foundation_models/ssl4eo_mae_optical.yaml @@ -33,9 +33,7 @@ input_bands: - B10 - B11 - B12 - sar: - - VV - - VH + output_layers: - 3 diff --git a/configs/foundation_models/ssl4eo_mae_sar.yaml b/configs/foundation_models/ssl4eo_mae_sar.yaml new file mode 100644 index 00000000..d463f0fe --- /dev/null +++ b/configs/foundation_models/ssl4eo_mae_sar.yaml @@ -0,0 +1,30 @@ +encoder_name: 
SSL4EO_MAE_SAR_Encoder +foundation_model_name: ssl4eo_mae_vit_small_patch16 +encoder_weights: ./pretrained_models/B2_vits16_mae_ep99.pth +download_url: https://huggingface.co/wangyi111/SSL4EO-S12/resolve/main/B2_vits16_mae_ep99.pth + +temporal_input: False + +num_layers: 12 +embed_dim: 384 +input_size: 224 + +encoder_model_args: + img_size: 224 + in_chans: 2 + embed_dim: 384 + patch_size: 16 + num_heads: 6 + depth: 12 + mlp_ratio: 4 + +input_bands: + sar: + - VV + - VH + +output_layers: + - 3 + - 5 + - 7 + - 11 diff --git a/configs/segmentors/unet_cd_binary.yaml b/configs/segmentors/siamconcunet_binary.yaml similarity index 94% rename from configs/segmentors/unet_cd_binary.yaml rename to configs/segmentors/siamconcunet_binary.yaml index 9d178ea3..80318e82 100644 --- a/configs/segmentors/unet_cd_binary.yaml +++ b/configs/segmentors/siamconcunet_binary.yaml @@ -1,4 +1,4 @@ -segmentor_name: UNetCD +segmentor_name: SiamConcUNet task_name: change-detection binary: True #task_model_args: diff --git a/configs/segmentors/upernet_cd_binary.yaml b/configs/segmentors/siamconcupernet_binary.yaml similarity index 93% rename from configs/segmentors/upernet_cd_binary.yaml rename to configs/segmentors/siamconcupernet_binary.yaml index 0773647c..445fb04d 100644 --- a/configs/segmentors/upernet_cd_binary.yaml +++ b/configs/segmentors/siamconcupernet_binary.yaml @@ -1,4 +1,4 @@ -segmentor_name: UPerNetCD +segmentor_name: SiamConcUPerNet task_name: change-detection binary: True #task_model_args: diff --git a/configs/segmentors/siamdiffunet_binary.yaml b/configs/segmentors/siamdiffunet_binary.yaml new file mode 100644 index 00000000..0b2ff043 --- /dev/null +++ b/configs/segmentors/siamdiffunet_binary.yaml @@ -0,0 +1,23 @@ +segmentor_name: SiamDiffUNet +task_name: change-detection +binary: True + #task_model_args: + #num_frames: 1 + #mt_strategy: "ltae" #activated only when if num_frames > 1 +#num_classes - task parameter passed from the dataset config +#wave_list - task parameter 
passed from the dataset config + +channels: 512 + +loss: + loss_name: DICELoss + ignore_index: -1 + +optimizer: + optimizer_name: AdamW + lr: 0.0001 + weight_decay: 0.05 + +scheduler: + scheduler_name: MultiStepLR + lr_milestones: [0.6, 0.9] \ No newline at end of file diff --git a/configs/segmentors/upernet_cd.yaml b/configs/segmentors/siamdiffupernet.yaml similarity index 93% rename from configs/segmentors/upernet_cd.yaml rename to configs/segmentors/siamdiffupernet.yaml index 12b026da..0b7cc8d0 100644 --- a/configs/segmentors/upernet_cd.yaml +++ b/configs/segmentors/siamdiffupernet.yaml @@ -1,4 +1,4 @@ -segmentor_name: UPerNetCD +segmentor_name: SiamDiffUPerNet task_name: change-detection binary: False #task_model_args: diff --git a/configs/segmentors/siamdiffupernet_binary.yaml b/configs/segmentors/siamdiffupernet_binary.yaml new file mode 100644 index 00000000..59766a6f --- /dev/null +++ b/configs/segmentors/siamdiffupernet_binary.yaml @@ -0,0 +1,26 @@ +segmentor_name: SiamDiffUPerNet +task_name: change-detection +binary: True + #task_model_args: + #num_frames: 1 + #mt_strategy: "ltae" #activated only when if num_frames > 1 +#num_classes - task parameter passed from the dataset config +#wave_list - task parameter passed from the dataset config + +channels: 512 + +loss: + loss_name: DICELoss + ignore_index: -1 + +optimizer: + optimizer_name: AdamW + lr: 0.0001 + weight_decay: 0.05 + +scheduler: + scheduler_name: MultiStepLR + lr_milestones: [0.6, 0.9] + + + diff --git a/datasets/__init__.py b/datasets/__init__.py index f921292e..026a3837 100644 --- a/datasets/__init__.py +++ b/datasets/__init__.py @@ -6,4 +6,5 @@ from .croptypemapping import CropTypeMappingSouthSudan from .ai4smallfarms import AI4SmallFarms from .spacenet7 import SN7MAPPING, SN7CD -from .fivebillionpixels import FiveBillionPixels \ No newline at end of file +from .fivebillionpixels import FiveBillionPixels +from .utae_dynamicen import DynamicEarthNet diff --git a/datasets/spacenet7.py 
b/datasets/spacenet7.py index f4b82b8d..c0af09f4 100644 --- a/datasets/spacenet7.py +++ b/datasets/spacenet7.py @@ -98,28 +98,28 @@ class AbstractSN7(torch.utils.data.Dataset): - def __init__(self, cfg, split): + def __init__(self, cfg): super().__init__() self.root_path = Path(cfg['root_path']) metadata_file = self.root_path / 'metadata_train.json' with open(metadata_file, 'r') as f: self.metadata = json.load(f) + self.img_size = 1024 # size of the SpaceNet 7 images + # unpacking config + self.tile_size = cfg['img_size'] # size used for tiling the images + assert self.img_size % self.tile_size == 0 + self.data_mean = cfg['data_mean'] self.data_std = cfg['data_std'] self.classes = cfg['classes'] - self.class_num = len(self.classes) self.distribution = cfg['distribution'] - self.split = split + self.domain_shift = cfg['domain_shift'] + self.i_split = cfg['i_split'] + self.j_split = cfg['j_split'] - if split == 'train': - self.aoi_ids = SN7_TRAIN - elif split == 'val': - self.aoi_ids = SN7_VAL - elif split == 'test': - self.aoi_ids = SN7_TEST - else: - raise Exception('Invalid split') + self.class_num = len(self.classes) + self.sn7_aois = list(SN7_TRAIN) + list(SN7_VAL) + list(SN7_TEST) @abstractmethod def __getitem__(self, index: int) -> dict: @@ -149,8 +149,8 @@ def load_building_label(self, aoi_id: str, year: int, month: int) -> np.ndarray: def load_change_label(self, aoi_id: str, year_t1: int, month_t1: int, year_t2: int, month_t2) -> np.ndarray: building_t1 = self.load_building_label(aoi_id, year_t1, month_t1) building_t2 = self.load_building_label(aoi_id, year_t2, month_t2) - change = building_t1 != building_t2 - return change.astype(np.float32) + change = np.not_equal(building_t1, building_t2) + return change.astype(np.int64) @staticmethod def get_band(path): @@ -195,14 +195,68 @@ def download(dataset_config: dict, silent=False): @DATASET_REGISTRY.register() class SN7MAPPING(AbstractSN7): def __init__(self, cfg, split): - super().__init__(cfg, split) + 
super().__init__(cfg) + self.split = split self.items = [] - # adding timestamps (only if label exists and not masked) for each AOI - for aoi_id in self.aoi_ids: - timestamps = list(self.metadata[aoi_id]) - timestamps = [ts for ts in timestamps if not ts['mask'] and ts['label']] - self.items.extend(timestamps) + + if self.domain_shift: # split by AOI ids + if split == 'train': + self.aoi_ids = list(SN7_TRAIN) + elif split == 'val': + self.aoi_ids = list(SN7_VAL) + elif split == 'test': + self.aoi_ids = list(SN7_TEST) + else: + raise Exception('Unkown split') + + # adding timestamps (only if label exists and not masked) for each AOI + for aoi_id in self.aoi_ids: + timestamps = list(self.metadata[aoi_id]) + for timestamp in timestamps: + if not timestamp['mask'] and timestamp['label']: + item = { + 'aoi_id': timestamp['aoi_id'], + 'year': timestamp['year'], + 'month': timestamp['month'], + } + # tiling the timestamps + for i in range(0, self.img_size, self.tile_size): + for j in range(0, self.img_size, self.tile_size): + item['i'] = i + item['j'] = j + self.items.append(dict(item)) + + else: # within-scenes split + assert self.i_split % self.tile_size == 0 and self.j_split % self.tile_size == 0 + assert self.tile_size <= self.i_split and self.tile_size <= self.j_split + self.aoi_ids = list(self.sn7_aois) + for aoi_id in self.aoi_ids: + timestamps = list(self.metadata[aoi_id]) + for timestamp in timestamps: + if not timestamp['mask'] and timestamp['label']: + item = { + 'aoi_id': timestamp['aoi_id'], + 'year': timestamp['year'], + 'month': timestamp['month'], + } + if split == 'train': + i_min, i_max = 0, self.i_split + j_min, j_max = 0, self.img_size + elif split == 'val': + i_min, i_max = self.i_split, self.img_size + j_min, j_max = 0, self.j_split + elif split == 'test': + i_min, i_max = self.i_split, self.img_size + j_min, j_max = self.j_split, self.img_size + else: + raise Exception('Unkown split') + # tiling the timestamps + for i in range(i_min, i_max, 
self.tile_size): + for j in range(j_min, j_max, self.tile_size): + item['i'] = i + item['j'] = j + self.items.append(dict(item)) def __len__(self): return len(self.items) @@ -215,6 +269,11 @@ def __getitem__(self, index): image = self.load_planet_mosaic(aoi_id, year, month) target = self.load_building_label(aoi_id, year, month) + # cut to tile + i, j = item['i'], item['j'] + image = image[:, i:i + self.tile_size, j:j + self.tile_size] + target = target[i:i + self.tile_size, j:j + self.tile_size] + image = torch.from_numpy(image) target = torch.from_numpy(target) weight = torch.empty(target.shape) @@ -243,19 +302,69 @@ def get_splits(dataset_config): @DATASET_REGISTRY.register() class SN7CD(AbstractSN7): def __init__(self, cfg, split, eval_mode): - super().__init__(cfg, split) + super().__init__(cfg) self.T = cfg['multi_temporal'] assert self.T > 1 + self.eval_mode = eval_mode - self.multiplier = 1 if eval_mode else 100 # TODO: get this from config - self.items = self.multiplier * list(self.aoi_ids) + self.multiplier = 1 if eval_mode else cfg['dataset_multiplier'] + + self.split = split + self.items = [] + + if self.domain_shift: # split by AOI ids + if split == 'train': + self.aoi_ids = list(SN7_TRAIN) + elif split == 'val': + self.aoi_ids = list(SN7_VAL) + elif split == 'test': + self.aoi_ids = list(SN7_TEST) + else: + raise Exception('Unkown split') + + # adding timestamps (only if label exists and not masked) for each AOI + for aoi_id in self.aoi_ids: + item = { 'aoi_id': aoi_id } + # tiling the timestamps + for i in range(0, self.img_size, self.tile_size): + for j in range(0, self.img_size, self.tile_size): + item['i'] = i + item['j'] = j + self.items.append(dict(item)) + + else: # within-scenes split + assert self.i_split % self.tile_size == 0 and self.j_split % self.tile_size == 0 + assert self.tile_size <= self.i_split and self.tile_size <= self.j_split + self.aoi_ids = list(self.sn7_aois) + for aoi_id in self.aoi_ids: + item = { 'aoi_id': aoi_id } + if 
split == 'train': + i_min, i_max = 0, self.i_split + j_min, j_max = 0, self.img_size + elif split == 'val': + i_min, i_max = self.i_split, self.img_size + j_min, j_max = 0, self.j_split + elif split == 'test': + i_min, i_max = self.i_split, self.img_size + j_min, j_max = self.j_split, self.img_size + else: + raise Exception('Unkown split') + # tiling the timestamps + for i in range(i_min, i_max, self.tile_size): + for j in range(j_min, j_max, self.tile_size): + item['i'] = i + item['j'] = j + self.items.append(dict(item)) + + self.items = self.multiplier * list(self.items) def __len__(self): return len(self.items) def __getitem__(self, index): - aoi_id = self.items[index] + item = self.items[index] + aoi_id = item['aoi_id'] # determine timestamps for t1 and t2 (random for train and first-last for eval) timestamps = [ts for ts in self.metadata[aoi_id] if not ts['mask'] and ts['label']] @@ -281,6 +390,13 @@ def __getitem__(self, index): year_t2, month_t2 = timestamps[-1]['year'], timestamps[-1]['month'] target = self.load_change_label(aoi_id, year_t1, month_t1, year_t2, month_t2) target = torch.from_numpy(target) + + # cut to tile + i, j = item['i'], item['j'] + image = image[:, :, i:i + self.tile_size, j:j + self.tile_size] + target = target[i:i + self.tile_size, j:j + self.tile_size] + + # weight for oversampling weight = torch.empty(target.shape) for i, freq in enumerate(self.distribution): weight[target == i] = 1 - freq diff --git a/datasets/utae_dynamicen.py b/datasets/utae_dynamicen.py new file mode 100644 index 00000000..3747a643 --- /dev/null +++ b/datasets/utae_dynamicen.py @@ -0,0 +1,168 @@ +import os +import numpy as np +import rasterio +import torch +from torch.utils.data import Dataset +from torchvision import transforms +from datetime import datetime +import torchvision.transforms.functional as TF +import cv2 + +import random +from PIL import Image + +from utils.registry import DATASET_REGISTRY + +@DATASET_REGISTRY.register() +class 
DynamicEarthNet(Dataset): + #def __init__(self, root, mode, type, reference_date="2018-01-01", crop_size=512, num_classes=6, ignore_index=-1): + def __init__(self, cfg, split, is_train=True): + """ + Args: + cfg: dataset config; root_path, data_mean, data_std, classes and ignore_index are read from it + split: train/val/test -- selects the split file dynnet_training_splits/{split}.txt under root_path + is_train: stored as self.is_train default:True + + Note: + the time-period mode is fixed to 'single', i.e. only day '01' of each listed year-month is used + the reference date for the day offsets is fixed to 2018-01-01 + """ + self.root_path = cfg['root_path'] + self.data_mean = cfg['data_mean'] + self.data_std = cfg['data_std'] + self.classes = cfg['classes'] + self.ignore_index = cfg['ignore_index'] + self.class_num = len(self.classes) + self.split = split + self.is_train = is_train + + self.mode = 'single' + + self.files = [] + + reference_date = "2018-01-01" + self.reference_date = datetime(*map(int, reference_date.split("-"))) + + self.set_files() + + def set_files(self): + self.file_list = os.path.join(self.root_path, "dynnet_training_splits", f"{self.split}" + ".txt") + + file_list = [line.rstrip().split(' ') for line in tuple(open(self.file_list, "r"))] + #for + self.files, self.labels, self.year_months = list(zip(*file_list)) + self.files = [f.replace('/reprocess-cropped/UTM-24000/', '/planet/') for f in self.files] + + if self.mode == 'daily': + self.all_days = list(range(len(self.files))) + + for i in range(len(self.files)): + self.planet, self.day = [], [] + date_count = 0 + for _, _, infiles in os.walk(os.path.join(self.root_path, self.files[i][1:])): + for infile in sorted(infiles): + if infile.startswith(self.year_months[i]): + self.planet.append(os.path.join(self.files[i], infile)) + self.day.append((datetime(int(str(infile.split('.')[0])[:4]), int(str(infile.split('.')[0][5:7])), + int(str(infile.split('.')[0])[8:])) - self.reference_date).days) + date_count += 1 + self.all_days[i] = 
list(zip(self.planet, self.day)) + self.all_days[i].insert(0, date_count) + + else: + self.planet, self.day = [], [] + if self.mode == 'weekly': + self.dates = ['01', '05', '10', '15', '20', '25'] + elif self.mode == 'single': + self.dates = ['01'] + + for i, year_month in enumerate(self.year_months): + for date in self.dates: + curr_date = year_month + '-' + date + self.planet.append(os.path.join(self.files[i], curr_date + '.tif')) + self.day.append((datetime(int(str(curr_date)[:4]), int(str(curr_date[5:7])), + int(str(curr_date)[8:])) - self.reference_date).days) + self.planet_day = list(zip(*[iter(self.planet)] * len(self.dates), *[iter(self.day)] * len(self.dates))) + + + def load_data(self, index): + cur_images, cur_dates = [], [] + if self.mode == 'daily': + for i in range(1, self.all_days[index][0]+1): + img = rasterio.open(os.path.join(self.root_path, self.all_days[index][i][0][1:])) + red = img.read(3) + green = img.read(2) + blue = img.read(1) + nir = img.read(4) + image = np.dstack((red, green, blue, nir)) + cur_images.append(np.expand_dims(np.asarray(image, dtype=np.float32), axis=0)) # np.array already\ + cur_dates.append(self.all_days[index][i][1]) + + image_stack = np.concatenate(cur_images, axis=0) + dates = torch.from_numpy(np.array(cur_dates, dtype=np.int32)) + label = rasterio.open(os.path.join(self.root_path, self.labels[index][1:])) + label = label.read() + mask = np.zeros((label.shape[1], label.shape[2]), dtype=np.int32) + + for i in range(self.class_num + 1): + if i == 6: + mask[label[i, :, :] == 255] = -1 + else: + mask[label[i, :, :] == 255] = i + + return (image_stack, dates), mask + + else: + for i in range(len(self.dates)): + # read .tif + img = rasterio.open(os.path.join(self.root_path, self.planet_day[index][i][1:])) + red = img.read(3) + green = img.read(2) + blue = img.read(1) + nir = img.read(4) + image = np.dstack((red, green, blue, nir)) + cur_images.append(np.expand_dims(np.asarray(image, dtype=np.float32), axis=0)) # np.array 
already\ + image_stack = np.concatenate(cur_images, axis=0) + dates = torch.from_numpy(np.array(self.planet_day[index][len(self.dates):], dtype=np.int32)) + label = rasterio.open(os.path.join(self.root_path, self.labels[index][1:])) + label = label.read() + mask = np.zeros((label.shape[1], label.shape[2]), dtype=np.int32) + + for i in range(self.class_num + 1): + if i == 6: + mask[label[i, :, :] == 255] = -1 + else: + mask[label[i, :, :] == 255] = i + + return (image_stack, dates), mask + + def __len__(self): + return len(self.files) + + def __getitem__(self, index): + (images, dates), label = self.load_data(index) + + images = torch.from_numpy(images).permute(3, 0, 1, 2)#.transpose(0, 1) + label = torch.from_numpy(np.array(label, dtype=np.int32)).long() + + output = { + 'image': { + 'optical': images, + }, + 'target': label, + 'metadata': {} + } + + return output + #return {'img': images, 'label': label, 'meta': dates} + + @staticmethod + def get_splits(dataset_config): + dataset_train = DynamicEarthNet(cfg=dataset_config, split="train") + dataset_val = DynamicEarthNet(cfg=dataset_config, split="val") + dataset_test = DynamicEarthNet(cfg=dataset_config, split="test") + return dataset_train, dataset_val, dataset_test + + @staticmethod + def download(dataset_config: dict, silent=False): + pass diff --git a/engine/data_preprocessor.py b/engine/data_preprocessor.py index 40e9d02f..7eacfc8d 100644 --- a/engine/data_preprocessor.py +++ b/engine/data_preprocessor.py @@ -17,7 +17,7 @@ def get_collate_fn(cfg: omegaconf.DictConfig) -> Callable: - modalities = cfg.dataset.bands.keys() + modalities = cfg.encoder.input_bands.keys() def collate_fn( batch: dict[dict[str, torch.Tensor]] @@ -72,6 +72,7 @@ def __init__(self, dataset, cfg): # Either use unly these, or only the input arguments. 
self.root_cfg = cfg self.dataset_cfg = cfg.dataset + self.encoder_cfg = cfg.encoder self.root_path = cfg.dataset.root_path self.classes = cfg.dataset.classes self.class_num = len(self.classes) @@ -106,7 +107,7 @@ def __init__(self, dataset, cfg, local_cfg): ) # TO DO: other modalities - for modality in self.dataset_cfg.bands: + for modality in self.encoder_cfg.input_bands: new_stats = self.preprocessor[modality].preprocess_band_statistics( self.data_mean[modality], self.data_std[modality], @@ -123,7 +124,8 @@ def __getitem__(self, index): data = self.dataset[index] for k, v in data["image"].items(): - data["image"][k] = self.preprocessor[k](v) + if k in self.encoder_cfg.input_bands: + data["image"][k] = self.preprocessor[k](v) data["target"] = data["target"].long() return data @@ -138,7 +140,8 @@ def __getitem__(self, index): data = self.dataset[index] for k, v in data["image"].items(): - data["image"][k] = self.preprocessor[k](v) + if k in self.encoder_cfg.input_bands: + data["image"][k] = self.preprocessor[k](v) data["target"] = data["target"].float() return data @@ -293,7 +296,7 @@ def __getitem__(self, index): tiled_data = {"image": {}, "target": None} tiled_data["image"] = {} for k, v in data["image"].items(): - if k not in self.ignore_modalities: + if k not in self.ignore_modalities and k in self.encoder_cfg.input_bands: tiled_data["image"][k] = v[ ..., h : h + self.output_size, w : w + self.output_size ].clone() @@ -340,12 +343,12 @@ def __getitem__(self, index): data = self.dataset[index] if random.random() < self.ud_probability: for k, v in data["image"].items(): - if k not in self.ignore_modalities: + if k not in self.ignore_modalities and k in self.encoder_cfg.input_bands: data["image"][k] = torch.fliplr(v) data["target"] = torch.fliplr(data["target"]) if random.random() < self.lr_probability: for k, v in data["image"].items(): - if k not in self.ignore_modalities: + if k not in self.ignore_modalities and k in self.encoder_cfg.input_bands: 
data["image"][k] = torch.flipud(v) data["target"] = torch.flipud(data["target"]) return data @@ -361,7 +364,7 @@ def __init__(self, dataset, cfg, local_cfg): def __getitem__(self, index): data = self.dataset[index] if random.random() < self.probability: for k, v in data["image"].items(): - if k not in self.ignore_modalities: + if k not in self.ignore_modalities and k in self.encoder_cfg.input_bands: data["image"][k] = torch.pow(v, random.uniform(*self.gamma_range)) return data @@ -374,7 +377,7 @@ def __init__(self, dataset, cfg, local_cfg): self.data_mean_tensors = {} self.data_std_tensors = {} # Bands is a dict of {modality:[b1, b2, ...], ...} so it's keys are the modalaities in use - for modality in self.dataset_cfg.bands: + for modality in self.encoder_cfg.input_bands: self.data_mean_tensors[modality] = torch.tensor( self.data_mean[modality] ).reshape((-1, 1, 1, 1)) @@ -384,7 +387,7 @@ def __init__(self, dataset, cfg, local_cfg): def __getitem__(self, index): data = self.dataset[index] - for modality in data["image"]: + for modality in self.encoder_cfg.input_bands: if modality not in self.ignore_modalities: data["image"][modality] = ( data["image"][modality] - self.data_mean_tensors[modality] @@ -401,7 +404,7 @@ def __init__(self, dataset, cfg, local_cfg): self.data_max_tensors = {} self.min = local_cfg.min self.max = local_cfg.max - for modality in self.dataset_cfg.bands: + for modality in self.encoder_cfg.input_bands: self.data_min_tensors[modality] = torch.tensor( self.data_min[modality] ).reshape((-1, 1, 1, 1)) @@ -411,7 +414,7 @@ def __init__(self, dataset, cfg, local_cfg): def __getitem__(self, index): data = self.dataset[index] - for modality in data["image"]: + for modality in self.encoder_cfg.input_bands: if modality not in self.ignore_modalities: data["image"][modality] = ( (data["image"][modality] - self.data_min_tensors[modality]) @@ -460,20 +463,22 @@ def __getitem__(self, index): data = self.dataset[index] for k, v in data["image"].items(): - 
brightness = random.uniform(-self.brightness, self.brightness) - if random.random() < self.br_probability: - if k not in self.ignore_modalities: - data["image"][k] = self.adjust_brightness( - data["image"][k], brightness, self.clip + if k not in self.ignore_modalities and k in self.encoder_cfg.input_bands: + brightness = random.uniform(-self.brightness, self.brightness) + if random.random() < self.br_probability: + if k not in self.ignore_modalities: + data["image"][k] = self.adjust_brightness( + data["image"][k], brightness, self.clip ) for k, v in data["image"].items(): - if random.random() < self.ct_probability: - contrast = random.uniform(1 - self.contrast, 1 + self.contrast) - if k not in self.ignore_modalities: - data["image"][k] = self.adjust_contrast( - data["image"][k], contrast, self.clip - ) + if k not in self.ignore_modalities and k in self.encoder_cfg.input_bands: + if random.random() < self.ct_probability: + contrast = random.uniform(1 - self.contrast, 1 + self.contrast) + if k not in self.ignore_modalities: + data["image"][k] = self.adjust_contrast( + data["image"][k], contrast, self.clip + ) return data @@ -487,7 +492,7 @@ def __init__(self, dataset, cfg, local_cfg): def __getitem__(self, index): data = self.dataset[index] for k, v in data["image"].items(): - if k not in self.ignore_modalities: + if k not in self.ignore_modalities and k in self.encoder_cfg.input_bands: data["image"][k] = T.Resize(self.size)(v) if data["target"].ndim == 2: @@ -531,7 +536,7 @@ def __getitem__(self, index): output_size=(self.size, self.size), ) for k, v in data["image"].items(): - if k not in self.ignore_modalities: + if k not in self.ignore_modalities and k in self.encoder_cfg.input_bands: data["image"][k] = T.functional.crop(v, i, j, h, w) data["target"] = T.functional.crop(data["target"], i, j, h, w) @@ -584,7 +589,7 @@ def __getitem__(self, index): i, j, h, w = crop_candidates[crop_idx] for k, v in data["image"].items(): - if k not in self.ignore_modalities: + if k 
not in self.ignore_modalities and k in self.encoder_cfg.input_bands: data["image"][k] = T.functional.crop(v, i, j, h, w) data["target"] = T.functional.crop(data["target"], i, j, h, w) @@ -597,4 +602,4 @@ def __init__(self, dataset, cfg, local_cfg): if not local_cfg: local_cfg = omegaconf.OmegaConf.create() local_cfg.size = cfg.encoder.input_size - super().__init__(dataset, cfg, local_cfg) \ No newline at end of file + super().__init__(dataset, cfg, local_cfg) diff --git a/run.py b/run.py index 7be343c6..21c11cb7 100644 --- a/run.py +++ b/run.py @@ -171,7 +171,10 @@ def main(): project="geofm-bench", name=exp_name, config=OmegaConf.to_container(cfg, resolve=True), + resume='allow', + id=cfg.get('wandb_run_id'), ) + cfg['wandb_run_id'] = wandb.run.id # get datasets dataset = DATASET_REGISTRY.get(cfg.dataset.dataset_name) @@ -255,7 +258,8 @@ def main(): batch_size=cfg.batch_size, # cfg.dataset["batch"], num_workers=cfg.num_workers, pin_memory=True, - persistent_workers=True, + # persistent_workers=True causes memory leak + persistent_workers=False, worker_init_fn=seed_worker, generator=get_generator(cfg.seed), drop_last=True, @@ -268,7 +272,7 @@ def main(): num_workers=cfg.num_workers, pin_memory=True, persistent_workers=False, - # worker_init_fn=seed_worker, + worker_init_fn=seed_worker, # generator=g, drop_last=False, collate_fn=collate_fn, diff --git a/segmentors/__init__.py b/segmentors/__init__.py index f99a8b00..3db4ed44 100644 --- a/segmentors/__init__.py +++ b/segmentors/__init__.py @@ -1,5 +1,5 @@ -from .upernet import UPerNet, UPerNetCD, MTUPerNet +from .upernet import UPerNet, SiamDiffUPerNet, SiamConcUPerNet, MTUPerNet from .ltae import LTAE2d -from .unet import UNet, UNetCD +from .unet import UNet, SiamDiffUNet, SiamConcUNet -all = ['UPerNet', 'UPerNetCD', 'MTUPerNet', 'UNet', 'UNetCD'] \ No newline at end of file +all = ['UPerNet', 'SiamDiffUPerNet', 'SiamConcUPerNet', 'MTUPerNet', 'UNet', 'SiamDiffUNet', 'SiamConcUNet'] \ No newline at end of file 
diff --git a/segmentors/unet.py b/segmentors/unet.py index eb5463ef..dc7ec5bc 100644 --- a/segmentors/unet.py +++ b/segmentors/unet.py @@ -8,6 +8,7 @@ from typing import Sequence from utils.registry import SEGMENTOR_REGISTRY + @SEGMENTOR_REGISTRY.register() class UNet(nn.Module): """ @@ -40,16 +41,14 @@ def forward(self, img, output_shape=None): return output -@SEGMENTOR_REGISTRY.register() -class UNetCD(nn.Module): +class SiamUNet(nn.Module): """ """ - def __init__(self, args, cfg, encoder): + def __init__(self, args, cfg, encoder, strategy): super().__init__() # self.frozen_backbone = frozen_backbone - self.model_name = 'UNetCD' self.encoder = encoder self.finetune = args.finetune @@ -58,8 +57,14 @@ def __init__(self, args, cfg, encoder): self.align_corners = False self.num_classes = 1 if cfg['binary'] else cfg['num_classes'] - self.topology = encoder.topology - + self.strategy = strategy + if self.strategy == 'diff': + self.topology = encoder.topology + elif self.strategy == 'concat': + self.topology = [2 * features for features in encoder.topology] + else: + raise NotImplementedError + self.decoder = Decoder(self.topology) self.conv_seg = OutConv(self.topology[0], self.num_classes) @@ -72,12 +77,32 @@ def forward(self, img, output_shape=None): feat1 = self.encoder(img1) feat2= self.encoder(img2) - feat = [f2 - f1 for f2, f1 in zip(feat1, feat2)] + if self.strategy == 'diff': + feat = [f2 - f1 for f1, f2 in zip(feat1, feat2)] + elif self.strategy == 'concat': + feat = [torch.concat((f1, f2), dim=1) for f1, f2 in zip(feat1, feat2)] + else: + raise NotImplementedError + feat = self.decoder(feat) output = self.conv_seg(feat) return output +@SEGMENTOR_REGISTRY.register() +class SiamDiffUNet(SiamUNet): + # Siamese UNet for change detection with feature differencing strategy + def __init__(self, args, cfg, encoder): + super().__init__(args, cfg, encoder, 'diff') + + +@SEGMENTOR_REGISTRY.register() +class SiamConcUNet(SiamUNet): + # Siamese UNet for change detection 
with feature concatenation strategy + def __init__(self, args, cfg, encoder): + super().__init__(args, cfg, encoder, 'concat') + + class Decoder(nn.Module): def __init__(self, topology: Sequence[int]): super(Decoder, self).__init__() diff --git a/segmentors/upernet.py b/segmentors/upernet.py index d519ec28..425d8c98 100644 --- a/segmentors/upernet.py +++ b/segmentors/upernet.py @@ -19,7 +19,7 @@ class UPerNet(nn.Module): Module applied on the last feature. Default: (1, 2, 3, 6). """ - def __init__(self, args, cfg, encoder, pool_scales=(1, 2, 3, 6)): + def __init__(self, args, cfg, encoder, pool_scales=(1, 2, 3, 6), feature_multiplier: int = 1): super().__init__() # self.frozen_backbone = frozen_backbone @@ -27,6 +27,7 @@ def __init__(self, args, cfg, encoder, pool_scales=(1, 2, 3, 6)): self.model_name = 'UPerNet' self.encoder = encoder self.finetune = args.finetune + self.feature_multiplier = feature_multiplier if not self.finetune: for param in self.encoder.parameters(): @@ -36,11 +37,11 @@ def __init__(self, args, cfg, encoder, pool_scales=(1, 2, 3, 6)): # for param in self.backbone.parameters(): # param.requires_grad = False - self.neck = Feature2Pyramid(embed_dim=cfg['in_channels'], rescales=[4, 2, 1, 0.5]) + self.neck = Feature2Pyramid(embed_dim=cfg['in_channels'] * feature_multiplier, rescales=[4, 2, 1, 0.5]) self.align_corners = False - self.in_channels = [cfg['in_channels'] for _ in range(4)] + self.in_channels = [cfg['in_channels'] * feature_multiplier for _ in range(4)] self.channels = cfg['channels'] self.num_classes = 1 if cfg['binary'] else cfg['num_classes'] @@ -156,8 +157,6 @@ def _forward_feature(self, inputs): feats = self.fpn_bottleneck(fpn_outs) return feats - - def forward(self, img, output_shape=None): """Forward function.""" # if self.freezed_backbone: @@ -238,12 +237,21 @@ def forward(self, img, output_shape=None): output = F.interpolate(output, size=output_shape, mode='bilinear') return output - -@SEGMENTOR_REGISTRY.register() -class 
UPerNetCD(UPerNet): - def __init__(self, args, cfg, encoder, pool_scales=(1, 2, 3, 6)): - super().__init__(args, cfg, encoder, pool_scales=(1, 2, 3, 6)) + +class SiamUPerNet(UPerNet): + def __init__(self, args, cfg, encoder, pool_scales, strategy): + + self.strategy = strategy + if self.strategy == 'diff': + self.feature_multiplier = 1 + elif self.strategy == 'concat': + self.feature_multiplier = 2 + else: + raise NotImplementedError + + super().__init__(args, cfg, encoder, pool_scales=pool_scales, feature_multiplier=self.feature_multiplier) + def forward(self, img, output_shape=None): """Forward function for change detection.""" @@ -266,13 +274,17 @@ def forward(self, img, output_shape=None): else: feats = self.encoder(img) - feat1 = [] - feat2 = [] + feat1, feat2 = [], [] for i in range(len(feats)): feat1.append(feats[i][:,:,0, :, :].squeeze(2)) feat2.append(feats[i][:,:,1, :, :].squeeze(2)) - feat = [f2 - f1 for f2, f1 in zip(feat1, feat2)] + if self.strategy == 'diff': + feat = [f2 - f1 for f1, f2 in zip(feat1, feat2)] + elif self.strategy == 'concat': + feat = [torch.concat((f1, f2), dim=1) for f1, f2 in zip(feat1, feat2)] + else: + raise NotImplementedError feat = self.neck(feat) feat = self._forward_feature(feat) @@ -285,6 +297,21 @@ def forward(self, img, output_shape=None): return output + +@SEGMENTOR_REGISTRY.register() +class SiamDiffUPerNet(SiamUPerNet): + # Siamese UPerNet for change detection with feature differencing strategy + def __init__(self, args, cfg, encoder, pool_scales=(1, 2, 3, 6)): + super().__init__(args, cfg, encoder, pool_scales, 'diff') + + +@SEGMENTOR_REGISTRY.register() +class SiamConcUPerNet(SiamUPerNet): + # Siamese UPerNet for change detection with feature concatenation strategy + def __init__(self, args, cfg, encoder, pool_scales=(1, 2, 3, 6)): + super().__init__(args, cfg, encoder, pool_scales, 'concat') + + class PPM(nn.ModuleList): """Pooling Pyramid Module used in PSPNet.