From 5b23ef6e1595e58b9724f7057b59ce168ed2b70a Mon Sep 17 00:00:00 2001 From: SebastianHafner Date: Mon, 9 Sep 2024 09:12:24 +0200 Subject: [PATCH 1/9] first steps --- engine/trainer.py | 6 +++--- utils/losses.py | 30 +++++++++++++++++++----------- 2 files changed, 22 insertions(+), 14 deletions(-) diff --git a/engine/trainer.py b/engine/trainer.py index 4e329cd5..f57ef141 100644 --- a/engine/trainer.py +++ b/engine/trainer.py @@ -50,9 +50,9 @@ def train(self): #end_time = time.time() for epoch in range(self.start_epoch, self.epochs): # train the network for one epoch - if epoch % self.args.eval_interval == 0: - _, used_time = self.evaluator(self.model, f'epoch {epoch}') - self.training_stats['eval_time'].update(used_time) + # if epoch % self.args.eval_interval == 0: + # _, used_time = self.evaluator(self.model, f'epoch {epoch}') + # self.training_stats['eval_time'].update(used_time) self.logger.info("============ Starting epoch %i ... ============" % epoch) # set sampler diff --git a/utils/losses.py b/utils/losses.py index 772c50a9..bd6a6ab0 100644 --- a/utils/losses.py +++ b/utils/losses.py @@ -26,20 +26,28 @@ def __init__(self, cfg): self.ignore_index = cfg["ignore_index"] def forward(self, logits, target): - # Convert logits to probabilities using softmax - probs = F.softmax(logits, dim=1) - num_classes = logits.shape[1] + + # Convert logits to probabilities using softmax or sigmoid + if num_classes == 1: + probs = torch.sigmoid(logits) + else: + probs = F.softmax(logits, dim=1) + mask = (target != self.ignore_index) #mask_expand = mask.unsqueeze(1).expand_as(probs) - target_temp = target.clone() - target_temp[~mask] = 0 - - target_one_hot = F.one_hot(target_temp, num_classes=num_classes).permute(0, 3, 1, 2).float() - target_one_hot = target_one_hot * mask.unsqueeze(1).float() - - intersection = torch.sum(probs * target_one_hot, dim=(2, 3)) - union = torch.sum(probs + target_one_hot, dim=(2, 3)) + target = target.clone() + target[~mask] = 0 + + if num_classes == 1: + target = target.unsqueeze(1) + else: + target = F.one_hot(target, num_classes=num_classes) + target = target.permute(0, 3, 1, 2) + + target = target.float() * mask.unsqueeze(1).float() + intersection = torch.sum(probs * target, dim=(2, 3)) + union = torch.sum(probs + target, dim=(2, 3)) dice_score = (2. * intersection + 1e-6) / (union + 1e-6) # dice_loss = 1 - dice_score.mean(dim=1).mean() From 0891a38a07558217a76d7c0316711b1923abc184 Mon Sep 17 00:00:00 2001 From: SebastianHafner Date: Mon, 9 Sep 2024 16:25:36 +0200 Subject: [PATCH 2/9] added binary flag --- configs/segmentors/reg_upernet_mt.yaml | 1 + configs/segmentors/unet_binary.yaml | 1 + configs/segmentors/unet_cd_binary.yaml | 1 + configs/segmentors/upernet.yaml | 1 + configs/segmentors/upernet_binary.yaml | 1 + configs/segmentors/upernet_cd.yaml | 1 + configs/segmentors/upernet_cd_binary.yaml | 1 + configs/segmentors/upernet_mt.yaml | 1 + 8 files changed, 8 insertions(+) diff --git a/configs/segmentors/reg_upernet_mt.yaml b/configs/segmentors/reg_upernet_mt.yaml index a35de0b7..6802be33 100644 --- a/configs/segmentors/reg_upernet_mt.yaml +++ b/configs/segmentors/reg_upernet_mt.yaml @@ -1,5 +1,6 @@ segmentor_name: MTUPerNet task_name: regression +binary: False multi_temporal_strategy: linear # time_frames: 2 #task_model_args: diff --git a/configs/segmentors/unet_binary.yaml b/configs/segmentors/unet_binary.yaml index e96bfd92..3e75fb98 100644 --- a/configs/segmentors/unet_binary.yaml +++ b/configs/segmentors/unet_binary.yaml @@ -1,5 +1,6 @@ segmentor_name: UNet task_name: semantic-segmentation +binary: True # time_frames: 1 #task_model_args: #num_frames: 1 diff --git a/configs/segmentors/unet_cd_binary.yaml b/configs/segmentors/unet_cd_binary.yaml index 26b1d5e1..9d178ea3 100644 --- a/configs/segmentors/unet_cd_binary.yaml +++ b/configs/segmentors/unet_cd_binary.yaml @@ -1,5 +1,6 @@ segmentor_name: UNetCD task_name: change-detection +binary: True #task_model_args: #num_frames: 1 #mt_strategy: "ltae" #activated only when if num_frames > 1 diff --git a/configs/segmentors/upernet.yaml b/configs/segmentors/upernet.yaml index 135a2f3d..012b765f 100644 --- a/configs/segmentors/upernet.yaml +++ b/configs/segmentors/upernet.yaml @@ -1,5 +1,6 @@ segmentor_name: UPerNet task_name: semantic-segmentation +binary: False # time_frames: 2 #task_model_args: #num_frames: 1 diff --git a/configs/segmentors/upernet_binary.yaml b/configs/segmentors/upernet_binary.yaml index 5e52c00d..74b76cb4 100644 --- a/configs/segmentors/upernet_binary.yaml +++ b/configs/segmentors/upernet_binary.yaml @@ -1,5 +1,6 @@ segmentor_name: UPerNet task_name: semantic-segmentation +binary: True # time_frames: 2 #task_model_args: #num_frames: 1 diff --git a/configs/segmentors/upernet_cd.yaml b/configs/segmentors/upernet_cd.yaml index 6504eb57..12b026da 100644 --- a/configs/segmentors/upernet_cd.yaml +++ b/configs/segmentors/upernet_cd.yaml @@ -1,5 +1,6 @@ segmentor_name: UPerNetCD task_name: change-detection +binary: False #task_model_args: #num_frames: 1 #mt_strategy: "ltae" #activated only when if num_frames > 1 diff --git a/configs/segmentors/upernet_cd_binary.yaml b/configs/segmentors/upernet_cd_binary.yaml index 64a28081..0773647c 100644 --- a/configs/segmentors/upernet_cd_binary.yaml +++ b/configs/segmentors/upernet_cd_binary.yaml @@ -1,5 +1,6 @@ segmentor_name: UPerNetCD task_name: change-detection +binary: True #task_model_args: #num_frames: 1 #mt_strategy: "ltae" #activated only when if num_frames > 1 diff --git a/configs/segmentors/upernet_mt.yaml b/configs/segmentors/upernet_mt.yaml index f2e47963..ae5582c0 100644 --- a/configs/segmentors/upernet_mt.yaml +++ b/configs/segmentors/upernet_mt.yaml @@ -1,5 +1,6 @@ segmentor_name: MTUPerNet task_name: semantic-segmentation +binary: False multi_temporal_strategy: linear # time_frames: 2 #task_model_args: From 67930d89c956a81ce140fd029d05bd7d62f1aa64 Mon Sep 17 00:00:00 2001 From: SebastianHafner Date: Mon, 9 Sep 2024 16:25:55 +0200 Subject: [PATCH 3/9] binary support --- engine/trainer.py | 6 +++++- segmentors/unet.py | 2 +- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/engine/trainer.py b/engine/trainer.py index f57ef141..2e8255da 100644 --- a/engine/trainer.py +++ b/engine/trainer.py @@ -224,7 +224,11 @@ def compute_loss(self, logits, target): @torch.no_grad() def compute_logging_metrics(self, logits, target): # logits = F.interpolate(logits, size=target.shape[1:], mode='bilinear') - pred = torch.argmax(logits, dim=1, keepdim=True) + num_classes = logits.shape[1] + if num_classes == 1: + pred = torch.sigmoid(logits) > 0.5 + else: + pred = torch.argmax(logits, dim=1, keepdim=True) target = target.unsqueeze(1) ignore_mask = target == -1 target[ignore_mask] = 0 diff --git a/segmentors/unet.py b/segmentors/unet.py index 32bed0fd..abfc8ab0 100644 --- a/segmentors/unet.py +++ b/segmentors/unet.py @@ -25,7 +25,7 @@ def __init__(self, args, cfg, encoder): self.align_corners = False - self.num_classes = cfg['num_classes'] + self.num_classes = 1 if cfg['binary'] else cfg['num_classes'] self.topology = encoder.topology self.decoder = Decoder(self.topology) From 5dd219974365bc87e2701dc33e4621ac8916f6a5 Mon Sep 17 00:00:00 2001 From: SebastianHafner Date: Mon, 9 Sep 2024 17:25:13 +0200 Subject: [PATCH 4/9] quick fix for dimensions in logging --- engine/trainer.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/engine/trainer.py b/engine/trainer.py index 2e8255da..193ff731 100644 --- a/engine/trainer.py +++ b/engine/trainer.py @@ -226,16 +226,19 @@ def compute_logging_metrics(self, logits, target): # logits = F.interpolate(logits, size=target.shape[1:], mode='bilinear') num_classes = logits.shape[1] if num_classes == 1: - pred = torch.sigmoid(logits) > 0.5 + pred = (torch.sigmoid(logits) > 0.5).type(torch.int64) else: pred = torch.argmax(logits, dim=1, keepdim=True) target = target.unsqueeze(1) ignore_mask = target == -1 target[ignore_mask] = 0 - ignore_mask = ignore_mask.expand(-1, logits.shape[1], -1, -1) + ignore_mask = ignore_mask.expand(-1, num_classes if num_classes > 1 else 2, -1, -1) - binary_pred = torch.zeros(logits.shape, dtype=bool, device=self.device) - binary_target = torch.zeros(logits.shape, dtype=bool, device=self.device) + dims = list(logits.shape) + if num_classes == 1: + dims[1] = 2 + binary_pred = torch.zeros(dims, dtype=bool, device=self.device) + binary_target = torch.zeros(dims, dtype=bool, device=self.device) binary_pred.scatter_(dim=1, index=pred, src=torch.ones_like(binary_pred)) binary_target.scatter_(dim=1, index=target, src=torch.ones_like(binary_target)) binary_pred[ignore_mask] = 0 From eb90b743cdf72a5a1dfb49bf816aaadc0c505a92 Mon Sep 17 00:00:00 2001 From: SebastianHafner Date: Mon, 9 Sep 2024 18:00:52 +0200 Subject: [PATCH 5/9] turn binary flag off --- configs/segmentors/unet_binary.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/configs/segmentors/unet_binary.yaml b/configs/segmentors/unet_binary.yaml index 3e75fb98..03a48c35 100644 --- a/configs/segmentors/unet_binary.yaml +++ b/configs/segmentors/unet_binary.yaml @@ -1,6 +1,6 @@ segmentor_name: UNet task_name: semantic-segmentation -binary: True +binary: False # time_frames: 1 #task_model_args: #num_frames: 1 From ef5384ce16afe50d6e6363f4b7df4e92a7c79799 Mon Sep 17 00:00:00 2001 From: SebastianHafner Date: Mon, 9 Sep 2024 18:16:35 +0200 Subject: [PATCH 6/9] fixed evaluation for binary problems --- engine/evaluator.py | 5 ++++- engine/trainer.py | 6 +++--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/engine/evaluator.py b/engine/evaluator.py index 66dd6ba5..a2e6bd4d 100644 --- a/engine/evaluator.py +++ b/engine/evaluator.py @@ -63,7 +63,10 @@ def evaluate(self, model, model_name='model', model_ckpt_path=None): target = target.to(self.device) logits = model(image, output_shape=target.shape[-2:]) - pred = torch.argmax(logits, dim=1) + if logits.shape[1] == 1: + pred = (torch.sigmoid(logits) > 0.5).type(torch.int64).squeeze(dim=1) + else: + pred = torch.argmax(logits, dim=1) valid_mask = target != -1 pred, target = pred[valid_mask], target[valid_mask] count = torch.bincount((pred * self.num_classes + target), minlength=self.num_classes ** 2) diff --git a/engine/trainer.py b/engine/trainer.py index 193ff731..470299c5 100644 --- a/engine/trainer.py +++ b/engine/trainer.py @@ -50,9 +50,9 @@ def train(self): #end_time = time.time() for epoch in range(self.start_epoch, self.epochs): # train the network for one epoch - # if epoch % self.args.eval_interval == 0: - # _, used_time = self.evaluator(self.model, f'epoch {epoch}') - # self.training_stats['eval_time'].update(used_time) + if epoch % self.args.eval_interval == 0: + _, used_time = self.evaluator(self.model, f'epoch {epoch}') + self.training_stats['eval_time'].update(used_time) self.logger.info("============ Starting epoch %i ... ============" % epoch) # set sampler From f98edc4a4002d6fedd376f68df1d0124c29a8b71 Mon Sep 17 00:00:00 2001 From: SebastianHafner Date: Mon, 9 Sep 2024 20:37:10 +0200 Subject: [PATCH 7/9] binary flag on --- configs/segmentors/unet_binary.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/configs/segmentors/unet_binary.yaml b/configs/segmentors/unet_binary.yaml index 03a48c35..3e75fb98 100644 --- a/configs/segmentors/unet_binary.yaml +++ b/configs/segmentors/unet_binary.yaml @@ -1,6 +1,6 @@ segmentor_name: UNet task_name: semantic-segmentation -binary: False +binary: True # time_frames: 1 #task_model_args: #num_frames: 1 From f545c4aa90445be2e01da1c2dcd5cd586995ce6f Mon Sep 17 00:00:00 2001 From: SebastianHafner Date: Mon, 9 Sep 2024 22:06:55 +0200 Subject: [PATCH 8/9] oversampling with random crops --- .../segmentation_oversampling.yaml | 11 ++++ datasets/spacenet7.py | 11 +++- engine/data_preprocessor.py | 52 +++++++++++++++++++ segmentors/unet.py | 2 +- segmentors/upernet.py | 2 +- 5 files changed, 75 insertions(+), 3 deletions(-) create mode 100644 configs/augmentations/segmentation_oversampling.yaml diff --git a/configs/augmentations/segmentation_oversampling.yaml b/configs/augmentations/segmentation_oversampling.yaml new file mode 100644 index 00000000..cf0ee149 --- /dev/null +++ b/configs/augmentations/segmentation_oversampling.yaml @@ -0,0 +1,11 @@ +train: + SegPreprocessor: ~ + NormalizeMeanStd: ~ + ImportanceRandomCropToEncoder: ~ + # RandomFlip: + # ud_probability: 0.3 + # lr_probability: 0.3 +test: + SegPreprocessor: ~ + NormalizeMeanStd: ~ + Tile: ~ \ No newline at end of file diff --git a/datasets/spacenet7.py b/datasets/spacenet7.py index 5a91b221..83efa996 100644 --- a/datasets/spacenet7.py +++ b/datasets/spacenet7.py @@ -109,6 +109,7 @@ def __init__(self, cfg, split): self.data_std = cfg['data_std'] self.classes = cfg['classes'] self.class_num = len(self.classes) + self.distribution = cfg['distribution'] self.split = split if split == 'train': @@ -216,12 +217,16 @@ def __getitem__(self, index): image = torch.from_numpy(image) target = torch.from_numpy(target) + weight = torch.empty(target.shape) + for i, freq in enumerate(self.distribution): + weight[target == i] = 1 - freq output = { 'image': { 'optical': image, }, 'target': target, + 'weight': weight, 'metadata': {} } @@ -243,7 +248,7 @@ def __init__(self, cfg, split, eval_mode): self.T = cfg['multi_temporal'] assert self.T > 1 self.eval_mode = eval_mode - self.items = list(self.aoi_ids) + self.items = 100 * list(self.aoi_ids) def __len__(self): return len(self.items) @@ -275,12 +280,16 @@ def __getitem__(self, index): year_t2, month_t2 = timestamps[-1]['year'], timestamps[-1]['month'] target = self.load_change_label(aoi_id, year_t1, month_t1, year_t2, month_t2) target = torch.from_numpy(target) + weight = torch.empty(target.shape) + for i, freq in enumerate(self.distribution): + weight[target == i] = 1 - freq output = { 'image': { 'optical': image, }, 'target': target, + 'weight': weight, 'metadata': {} } diff --git a/engine/data_preprocessor.py b/engine/data_preprocessor.py index 8ef3885d..9ed5302d 100644 --- a/engine/data_preprocessor.py +++ b/engine/data_preprocessor.py @@ -551,3 +551,55 @@ def __init__(self, dataset, cfg, local_cfg): local_cfg.size = cfg.encoder.input_size super().__init__(dataset, cfg, local_cfg) + + +@AUGMENTER_REGISTRY.register() +class ImportanceRandomCrop(BaseAugment): + def __init__(self, dataset, cfg, local_cfg): + super().__init__(dataset, cfg, local_cfg) + self.size = local_cfg.size + self.padding = getattr(local_cfg, "padding", None) + self.pad_if_needed = getattr(local_cfg, "pad_if_needed", False) + self.fill = getattr(local_cfg, "fill", 0) + self.padding_mode = getattr(local_cfg, "padding_mode", "constant") + self.n_crops = 10 # TODO: put this one in config + + def __getitem__(self, index): + data = self.dataset[index] + + # dataset needs to provide a weighting layer + assert 'weight' in data.keys() + + # candidates for random crop + crop_candidates, crop_weights = [], [] + for _ in range(self.n_crops): + i, j, h, w = T.RandomCrop.get_params( + data["image"][ + list(data["image"].keys())[0] + ], # Use the first image to determine parameters + output_size=(self.size, self.size), + ) + crop_candidates.append((i, j, h, w)) + + crop_weight = T.functional.crop(data['weight'], i, j, h, w) + crop_weights.append(torch.sum(crop_weight).item()) + + crop_weights = np.array(crop_weights) / sum(crop_weights) + crop_idx = np.random.choice(self.n_crops, p=crop_weights) + i, j, h, w = crop_candidates[crop_idx] + + for k, v in data["image"].items(): + if k not in self.ignore_modalities: + data["image"][k] = T.functional.crop(v, i, j, h, w) + data["target"] = T.functional.crop(data["target"], i, j, h, w) + + return data + + +@AUGMENTER_REGISTRY.register() +class ImportanceRandomCropToEncoder(ImportanceRandomCrop): + def __init__(self, dataset, cfg, local_cfg): + if not local_cfg: + local_cfg = omegaconf.OmegaConf.create() + local_cfg.size = cfg.encoder.input_size + super().__init__(dataset, cfg, local_cfg) \ No newline at end of file diff --git a/segmentors/unet.py b/segmentors/unet.py index abfc8ab0..eb5463ef 100644 --- a/segmentors/unet.py +++ b/segmentors/unet.py @@ -57,7 +57,7 @@ def __init__(self, args, cfg, encoder): self.align_corners = False - self.num_classes = cfg['num_classes'] + self.num_classes = 1 if cfg['binary'] else cfg['num_classes'] self.topology = encoder.topology self.decoder = Decoder(self.topology) diff --git a/segmentors/upernet.py b/segmentors/upernet.py index 2a1613f4..5312e131 100644 --- a/segmentors/upernet.py +++ b/segmentors/upernet.py @@ -48,7 +48,7 @@ def __init__(self, args, cfg, encoder, pool_scales=(1, 2, 3, 6)): self.in_channels = [cfg['in_channels'] for _ in range(4)] self.channels = cfg['channels'] - self.num_classes = cfg['num_classes'] + self.num_classes = 1 if cfg['binary'] else cfg['num_classes'] # PSP Module self.psp_modules = PPM( From c97b385d22db2d22616f1758651955ff665c9f07 Mon Sep 17 00:00:00 2001 From: SebastianHafner Date: Mon, 9 Sep 2024 22:19:21 +0200 Subject: [PATCH 9/9] fixed multiplier --- datasets/spacenet7.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/datasets/spacenet7.py b/datasets/spacenet7.py index 83efa996..f4b82b8d 100644 --- a/datasets/spacenet7.py +++ b/datasets/spacenet7.py @@ -248,7 +248,8 @@ def __init__(self, cfg, split, eval_mode): self.T = cfg['multi_temporal'] assert self.T > 1 self.eval_mode = eval_mode - self.items = 100 * list(self.aoi_ids) + self.multiplier = 1 if eval_mode else 100 # TODO: get this from config + self.items = self.multiplier * list(self.aoi_ids) def __len__(self): return len(self.items)