diff --git a/configs/augmentations/segmentation_oversampling.yaml b/configs/augmentations/segmentation_oversampling.yaml new file mode 100644 index 00000000..cf0ee149 --- /dev/null +++ b/configs/augmentations/segmentation_oversampling.yaml @@ -0,0 +1,11 @@ +train: + SegPreprocessor: ~ + NormalizeMeanStd: ~ + ImportanceRandomCropToEncoder: ~ + # RandomFlip: + # ud_probability: 0.3 + # lr_probability: 0.3 +test: + SegPreprocessor: ~ + NormalizeMeanStd: ~ + Tile: ~ \ No newline at end of file diff --git a/configs/segmentors/reg_upernet_mt.yaml b/configs/segmentors/reg_upernet_mt.yaml index a35de0b7..6802be33 100644 --- a/configs/segmentors/reg_upernet_mt.yaml +++ b/configs/segmentors/reg_upernet_mt.yaml @@ -1,5 +1,6 @@ segmentor_name: MTUPerNet task_name: regression +binary: False multi_temporal_strategy: linear # time_frames: 2 #task_model_args: diff --git a/configs/segmentors/unet_binary.yaml b/configs/segmentors/unet_binary.yaml index e96bfd92..3e75fb98 100644 --- a/configs/segmentors/unet_binary.yaml +++ b/configs/segmentors/unet_binary.yaml @@ -1,5 +1,6 @@ segmentor_name: UNet task_name: semantic-segmentation +binary: True # time_frames: 1 #task_model_args: #num_frames: 1 diff --git a/configs/segmentors/unet_cd_binary.yaml b/configs/segmentors/unet_cd_binary.yaml index 26b1d5e1..9d178ea3 100644 --- a/configs/segmentors/unet_cd_binary.yaml +++ b/configs/segmentors/unet_cd_binary.yaml @@ -1,5 +1,6 @@ segmentor_name: UNetCD task_name: change-detection +binary: True #task_model_args: #num_frames: 1 #mt_strategy: "ltae" #activated only when if num_frames > 1 diff --git a/configs/segmentors/upernet.yaml b/configs/segmentors/upernet.yaml index 135a2f3d..012b765f 100644 --- a/configs/segmentors/upernet.yaml +++ b/configs/segmentors/upernet.yaml @@ -1,5 +1,6 @@ segmentor_name: UPerNet task_name: semantic-segmentation +binary: False # time_frames: 2 #task_model_args: #num_frames: 1 diff --git a/configs/segmentors/upernet_binary.yaml b/configs/segmentors/upernet_binary.yaml index 5e52c00d..74b76cb4 100644 --- a/configs/segmentors/upernet_binary.yaml +++ b/configs/segmentors/upernet_binary.yaml @@ -1,5 +1,6 @@ segmentor_name: UPerNet task_name: semantic-segmentation +binary: True # time_frames: 2 #task_model_args: #num_frames: 1 diff --git a/configs/segmentors/upernet_cd.yaml b/configs/segmentors/upernet_cd.yaml index 6504eb57..12b026da 100644 --- a/configs/segmentors/upernet_cd.yaml +++ b/configs/segmentors/upernet_cd.yaml @@ -1,5 +1,6 @@ segmentor_name: UPerNetCD task_name: change-detection +binary: False #task_model_args: #num_frames: 1 #mt_strategy: "ltae" #activated only when if num_frames > 1 diff --git a/configs/segmentors/upernet_cd_binary.yaml b/configs/segmentors/upernet_cd_binary.yaml index 64a28081..0773647c 100644 --- a/configs/segmentors/upernet_cd_binary.yaml +++ b/configs/segmentors/upernet_cd_binary.yaml @@ -1,5 +1,6 @@ segmentor_name: UPerNetCD task_name: change-detection +binary: True #task_model_args: #num_frames: 1 #mt_strategy: "ltae" #activated only when if num_frames > 1 diff --git a/configs/segmentors/upernet_mt.yaml b/configs/segmentors/upernet_mt.yaml index f2e47963..ae5582c0 100644 --- a/configs/segmentors/upernet_mt.yaml +++ b/configs/segmentors/upernet_mt.yaml @@ -1,5 +1,6 @@ segmentor_name: MTUPerNet task_name: semantic-segmentation +binary: False multi_temporal_strategy: linear # time_frames: 2 #task_model_args: diff --git a/datasets/spacenet7.py b/datasets/spacenet7.py index 5a91b221..f4b82b8d 100644 --- a/datasets/spacenet7.py +++ b/datasets/spacenet7.py @@ -109,6 +109,7 @@ def __init__(self, cfg, split): self.data_std = cfg['data_std'] self.classes = cfg['classes'] self.class_num = len(self.classes) + self.distribution = cfg['distribution'] self.split = split if split == 'train': @@ -216,12 +217,16 @@ def __getitem__(self, index): image = torch.from_numpy(image) target = torch.from_numpy(target) + weight = torch.empty(target.shape) + for i, freq in enumerate(self.distribution): + weight[target == i] = 1 - freq output = { 'image': { 'optical': image, }, 'target': target, + 'weight': weight, 'metadata': {} } @@ -243,7 +248,8 @@ def __init__(self, cfg, split, eval_mode): self.T = cfg['multi_temporal'] assert self.T > 1 self.eval_mode = eval_mode - self.items = list(self.aoi_ids) + self.multiplier = 1 if eval_mode else 100 # TODO: get this from config + self.items = self.multiplier * list(self.aoi_ids) def __len__(self): return len(self.items) @@ -275,12 +281,16 @@ def __getitem__(self, index): year_t2, month_t2 = timestamps[-1]['year'], timestamps[-1]['month'] target = self.load_change_label(aoi_id, year_t1, month_t1, year_t2, month_t2) target = torch.from_numpy(target) + weight = torch.empty(target.shape) + for i, freq in enumerate(self.distribution): + weight[target == i] = 1 - freq output = { 'image': { 'optical': image, }, 'target': target, + 'weight': weight, 'metadata': {} } diff --git a/engine/data_preprocessor.py b/engine/data_preprocessor.py index ac866494..40e9d02f 100644 --- a/engine/data_preprocessor.py +++ b/engine/data_preprocessor.py @@ -546,3 +546,55 @@ def __init__(self, dataset, cfg, local_cfg): local_cfg.size = cfg.encoder.input_size super().__init__(dataset, cfg, local_cfg) + + +@AUGMENTER_REGISTRY.register() +class ImportanceRandomCrop(BaseAugment): + def __init__(self, dataset, cfg, local_cfg): + super().__init__(dataset, cfg, local_cfg) + self.size = local_cfg.size + self.padding = getattr(local_cfg, "padding", None) + self.pad_if_needed = getattr(local_cfg, "pad_if_needed", False) + self.fill = getattr(local_cfg, "fill", 0) + self.padding_mode = getattr(local_cfg, "padding_mode", "constant") + self.n_crops = 10 # TODO: put this one in config + + def __getitem__(self, index): + data = self.dataset[index] + + # dataset needs to provide a weighting layer + assert 'weight' in data.keys() + + # candidates for random crop + crop_candidates, crop_weights = [], [] + for _ in range(self.n_crops): + i, j, h, w = T.RandomCrop.get_params( + data["image"][ + list(data["image"].keys())[0] + ], # Use the first image to determine parameters + output_size=(self.size, self.size), + ) + crop_candidates.append((i, j, h, w)) + + crop_weight = T.functional.crop(data['weight'], i, j, h, w) + crop_weights.append(torch.sum(crop_weight).item()) + + crop_weights = np.array(crop_weights) / sum(crop_weights) + crop_idx = np.random.choice(self.n_crops, p=crop_weights) + i, j, h, w = crop_candidates[crop_idx] + + for k, v in data["image"].items(): + if k not in self.ignore_modalities: + data["image"][k] = T.functional.crop(v, i, j, h, w) + data["target"] = T.functional.crop(data["target"], i, j, h, w) + + return data + + +@AUGMENTER_REGISTRY.register() +class ImportanceRandomCropToEncoder(ImportanceRandomCrop): + def __init__(self, dataset, cfg, local_cfg): + if not local_cfg: + local_cfg = omegaconf.OmegaConf.create() + local_cfg.size = cfg.encoder.input_size + super().__init__(dataset, cfg, local_cfg) \ No newline at end of file diff --git a/engine/evaluator.py b/engine/evaluator.py index 66dd6ba5..a2e6bd4d 100644 --- a/engine/evaluator.py +++ b/engine/evaluator.py @@ -63,7 +63,10 @@ def evaluate(self, model, model_name='model', model_ckpt_path=None): target = target.to(self.device) logits = model(image, output_shape=target.shape[-2:]) - pred = torch.argmax(logits, dim=1) + if logits.shape[1] == 1: + pred = (torch.sigmoid(logits) > 0.5).type(torch.int64).squeeze(dim=1) + else: + pred = torch.argmax(logits, dim=1) valid_mask = target != -1 pred, target = pred[valid_mask], target[valid_mask] count = torch.bincount((pred * self.num_classes + target), minlength=self.num_classes ** 2) diff --git a/engine/trainer.py b/engine/trainer.py index b6558f27..384ad236 100644 --- a/engine/trainer.py +++ b/engine/trainer.py @@ -224,14 +224,21 @@ def compute_loss(self, logits, target): @torch.no_grad() def compute_logging_metrics(self, logits, target): # logits = F.interpolate(logits, size=target.shape[1:], mode='bilinear') - pred = torch.argmax(logits, dim=1, keepdim=True) + num_classes = logits.shape[1] + if num_classes == 1: + pred = (torch.sigmoid(logits) > 0.5).type(torch.int64) + else: + pred = torch.argmax(logits, dim=1, keepdim=True) target = target.unsqueeze(1) ignore_mask = target == -1 target[ignore_mask] = 0 - ignore_mask = ignore_mask.expand(-1, logits.shape[1], -1, -1) + ignore_mask = ignore_mask.expand(-1, num_classes if num_classes > 1 else 2, -1, -1) - binary_pred = torch.zeros(logits.shape, dtype=bool, device=self.device) - binary_target = torch.zeros(logits.shape, dtype=bool, device=self.device) + dims = list(logits.shape) + if num_classes == 1: + dims[1] = 2 + binary_pred = torch.zeros(dims, dtype=bool, device=self.device) + binary_target = torch.zeros(dims, dtype=bool, device=self.device) binary_pred.scatter_(dim=1, index=pred, src=torch.ones_like(binary_pred)) binary_target.scatter_(dim=1, index=target, src=torch.ones_like(binary_target)) binary_pred[ignore_mask] = 0 diff --git a/segmentors/unet.py b/segmentors/unet.py index 32bed0fd..eb5463ef 100644 --- a/segmentors/unet.py +++ b/segmentors/unet.py @@ -25,7 +25,7 @@ def __init__(self, args, cfg, encoder): self.align_corners = False - self.num_classes = cfg['num_classes'] + self.num_classes = 1 if cfg['binary'] else cfg['num_classes'] self.topology = encoder.topology self.decoder = Decoder(self.topology) @@ -57,7 +57,7 @@ def __init__(self, args, cfg, encoder): self.align_corners = False - self.num_classes = cfg['num_classes'] + self.num_classes = 1 if cfg['binary'] else cfg['num_classes'] self.topology = encoder.topology self.decoder = Decoder(self.topology) diff --git a/segmentors/upernet.py b/segmentors/upernet.py index 2a1613f4..5312e131 100644 --- a/segmentors/upernet.py +++ b/segmentors/upernet.py @@ -48,7 +48,7 @@ def __init__(self, args, cfg, encoder, pool_scales=(1, 2, 3, 6)): self.in_channels = [cfg['in_channels'] for _ in range(4)] self.channels = cfg['channels'] - self.num_classes = cfg['num_classes'] + self.num_classes = 1 if cfg['binary'] else cfg['num_classes'] # PSP Module self.psp_modules = PPM( diff --git a/utils/losses.py b/utils/losses.py index 772c50a9..bd6a6ab0 100644 --- a/utils/losses.py +++ b/utils/losses.py @@ -26,20 +26,28 @@ def __init__(self, cfg): self.ignore_index = cfg["ignore_index"] def forward(self, logits, target): - # Convert logits to probabilities using softmax - probs = F.softmax(logits, dim=1) - num_classes = logits.shape[1] + + # Convert logits to probabilities using softmax or sigmoid + if num_classes == 1: + probs = torch.sigmoid(logits) + else: + probs = F.softmax(logits, dim=1) + mask = (target != self.ignore_index) #mask_expand = mask.unsqueeze(1).expand_as(probs) - target_temp = target.clone() - target_temp[~mask] = 0 - - target_one_hot = F.one_hot(target_temp, num_classes=num_classes).permute(0, 3, 1, 2).float() - target_one_hot = target_one_hot * mask.unsqueeze(1).float() - - intersection = torch.sum(probs * target_one_hot, dim=(2, 3)) - union = torch.sum(probs + target_one_hot, dim=(2, 3)) + target = target.clone() + target[~mask] = 0 + + if num_classes == 1: + target = target.unsqueeze(1) + else: + target = F.one_hot(target, num_classes=num_classes) + target = target.permute(0, 3, 1, 2) + + target = target.float() * mask.unsqueeze(1).float() + intersection = torch.sum(probs * target, dim=(2, 3)) + union = torch.sum(probs + target, dim=(2, 3)) dice_score = (2. * intersection + 1e-6) / (union + 1e-6) # dice_loss = 1 - dice_score.mean(dim=1).mean()