Skip to content

Commit

Permalink
Merge pull request #29 from yurujaja/binary_improvements
Browse files Browse the repository at this point in the history
Binary improvements
  • Loading branch information
EricBrune authored Sep 10, 2024
2 parents 5cc3bf9 + bfdc858 commit 00f1816
Show file tree
Hide file tree
Showing 16 changed files with 119 additions and 20 deletions.
11 changes: 11 additions & 0 deletions configs/augmentations/segmentation_oversampling.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
train:
SegPreprocessor: ~
NormalizeMeanStd: ~
ImportanceRandomCropToEncoder: ~
# RandomFlip:
# ud_probability: 0.3
# lr_probability: 0.3
test:
SegPreprocessor: ~
NormalizeMeanStd: ~
Tile: ~
1 change: 1 addition & 0 deletions configs/segmentors/reg_upernet_mt.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
segmentor_name: MTUPerNet
task_name: regression
binary: False
multi_temporal_strategy: linear
# time_frames: 2
#task_model_args:
Expand Down
1 change: 1 addition & 0 deletions configs/segmentors/unet_binary.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
segmentor_name: UNet
task_name: semantic-segmentation
binary: True
# time_frames: 1
#task_model_args:
#num_frames: 1
Expand Down
1 change: 1 addition & 0 deletions configs/segmentors/unet_cd_binary.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
segmentor_name: UNetCD
task_name: change-detection
binary: True
#task_model_args:
#num_frames: 1
#mt_strategy: "ltae" #activated only when num_frames > 1
Expand Down
1 change: 1 addition & 0 deletions configs/segmentors/upernet.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
segmentor_name: UPerNet
task_name: semantic-segmentation
binary: False
# time_frames: 2
#task_model_args:
#num_frames: 1
Expand Down
1 change: 1 addition & 0 deletions configs/segmentors/upernet_binary.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
segmentor_name: UPerNet
task_name: semantic-segmentation
binary: True
# time_frames: 2
#task_model_args:
#num_frames: 1
Expand Down
1 change: 1 addition & 0 deletions configs/segmentors/upernet_cd.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
segmentor_name: UPerNetCD
task_name: change-detection
binary: False
#task_model_args:
#num_frames: 1
#mt_strategy: "ltae" #activated only when num_frames > 1
Expand Down
1 change: 1 addition & 0 deletions configs/segmentors/upernet_cd_binary.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
segmentor_name: UPerNetCD
task_name: change-detection
binary: True
#task_model_args:
#num_frames: 1
#mt_strategy: "ltae" #activated only when num_frames > 1
Expand Down
1 change: 1 addition & 0 deletions configs/segmentors/upernet_mt.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
segmentor_name: MTUPerNet
task_name: semantic-segmentation
binary: False
multi_temporal_strategy: linear
# time_frames: 2
#task_model_args:
Expand Down
12 changes: 11 additions & 1 deletion datasets/spacenet7.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ def __init__(self, cfg, split):
self.data_std = cfg['data_std']
self.classes = cfg['classes']
self.class_num = len(self.classes)
self.distribution = cfg['distribution']
self.split = split

if split == 'train':
Expand Down Expand Up @@ -216,12 +217,16 @@ def __getitem__(self, index):

image = torch.from_numpy(image)
target = torch.from_numpy(target)
weight = torch.empty(target.shape)
for i, freq in enumerate(self.distribution):
weight[target == i] = 1 - freq

output = {
'image': {
'optical': image,
},
'target': target,
'weight': weight,
'metadata': {}
}

Expand All @@ -243,7 +248,8 @@ def __init__(self, cfg, split, eval_mode):
self.T = cfg['multi_temporal']
assert self.T > 1
self.eval_mode = eval_mode
self.items = list(self.aoi_ids)
self.multiplier = 1 if eval_mode else 100 # TODO: get this from config
self.items = self.multiplier * list(self.aoi_ids)

def __len__(self):
    """Return the number of entries in ``self.items``."""
    return len(self.items)
Expand Down Expand Up @@ -275,12 +281,16 @@ def __getitem__(self, index):
year_t2, month_t2 = timestamps[-1]['year'], timestamps[-1]['month']
target = self.load_change_label(aoi_id, year_t1, month_t1, year_t2, month_t2)
target = torch.from_numpy(target)
weight = torch.empty(target.shape)
for i, freq in enumerate(self.distribution):
weight[target == i] = 1 - freq

output = {
'image': {
'optical': image,
},
'target': target,
'weight': weight,
'metadata': {}
}

Expand Down
52 changes: 52 additions & 0 deletions engine/data_preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -546,3 +546,55 @@ def __init__(self, dataset, cfg, local_cfg):
local_cfg.size = cfg.encoder.input_size
super().__init__(dataset, cfg, local_cfg)



@AUGMENTER_REGISTRY.register()
class ImportanceRandomCrop(BaseAugment):
    """Random crop that favours windows with a high total sample weight.

    For each sample, ``n_crops`` candidate windows of ``size`` x ``size`` are
    drawn with ``T.RandomCrop.get_params``. Every candidate is scored by the
    sum of the dataset-provided ``data['weight']`` layer inside the window,
    and one candidate is sampled with probability proportional to its score.
    The selected window is then applied to every image modality that is not
    listed in ``self.ignore_modalities`` and to ``data['target']``.

    Config (``local_cfg``):
        size: side length of the square crop (required).
        padding, pad_if_needed, fill, padding_mode: stored on the instance;
            not used by ``__getitem__`` of this class.
        n_crops: number of candidate windows per sample (optional,
            defaults to 10).
    """

    def __init__(self, dataset, cfg, local_cfg):
        super().__init__(dataset, cfg, local_cfg)
        self.size = local_cfg.size
        self.padding = getattr(local_cfg, "padding", None)
        self.pad_if_needed = getattr(local_cfg, "pad_if_needed", False)
        self.fill = getattr(local_cfg, "fill", 0)
        self.padding_mode = getattr(local_cfg, "padding_mode", "constant")
        # Number of candidate windows; configurable, with the previous
        # hard-coded value as default. A missing key may surface as None
        # depending on the config object, so treat None as "not set".
        n_crops = getattr(local_cfg, "n_crops", None)
        self.n_crops = 10 if n_crops is None else int(n_crops)

    def __getitem__(self, index):
        """Return the sample at ``index`` cropped to an importance-sampled window.

        Raises:
            AssertionError: if the wrapped dataset does not provide a
                ``'weight'`` entry.
        """
        data = self.dataset[index]

        # dataset needs to provide a weighting layer
        assert 'weight' in data.keys(), "dataset must provide a 'weight' layer"

        # Use the first image to determine the crop parameters.
        reference = next(iter(data["image"].values()))

        # candidates for random crop, each scored by its summed weight
        crop_candidates, crop_weights = [], []
        for _ in range(self.n_crops):
            i, j, h, w = T.RandomCrop.get_params(
                reference,
                output_size=(self.size, self.size),
            )
            crop_candidates.append((i, j, h, w))

            crop_weight = T.functional.crop(data['weight'], i, j, h, w)
            crop_weights.append(torch.sum(crop_weight).item())

        crop_weights = np.asarray(crop_weights, dtype=np.float64)
        total = crop_weights.sum()
        if (
            np.all(np.isfinite(crop_weights))
            and np.all(crop_weights >= 0)
            and total > 0
        ):
            probabilities = crop_weights / total
        else:
            # All-zero, negative or non-finite scores cannot be normalised
            # into a probability vector (np.random.choice would raise);
            # fall back to a uniform choice among the candidates.
            probabilities = None
        crop_idx = np.random.choice(self.n_crops, p=probabilities)
        i, j, h, w = crop_candidates[crop_idx]

        for k, v in data["image"].items():
            if k not in self.ignore_modalities:
                data["image"][k] = T.functional.crop(v, i, j, h, w)
        data["target"] = T.functional.crop(data["target"], i, j, h, w)

        return data


@AUGMENTER_REGISTRY.register()
class ImportanceRandomCropToEncoder(ImportanceRandomCrop):
    """ImportanceRandomCrop whose crop size is taken from the encoder config.

    The crop side length is always set to ``cfg.encoder.input_size``; when no
    local config is given, an empty OmegaConf container is created to hold it.
    """

    def __init__(self, dataset, cfg, local_cfg):
        # Reuse the caller's config object when present (it is updated in
        # place), otherwise start from an empty one.
        crop_cfg = local_cfg if local_cfg else omegaconf.OmegaConf.create()
        crop_cfg.size = cfg.encoder.input_size
        super().__init__(dataset, cfg, crop_cfg)
5 changes: 4 additions & 1 deletion engine/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,10 @@ def evaluate(self, model, model_name='model', model_ckpt_path=None):
target = target.to(self.device)

logits = model(image, output_shape=target.shape[-2:])
pred = torch.argmax(logits, dim=1)
if logits.shape[1] == 1:
pred = (torch.sigmoid(logits) > 0.5).type(torch.int64).squeeze(dim=1)
else:
pred = torch.argmax(logits, dim=1)
valid_mask = target != -1
pred, target = pred[valid_mask], target[valid_mask]
count = torch.bincount((pred * self.num_classes + target), minlength=self.num_classes ** 2)
Expand Down
15 changes: 11 additions & 4 deletions engine/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,14 +224,21 @@ def compute_loss(self, logits, target):
@torch.no_grad()
def compute_logging_metrics(self, logits, target):
# logits = F.interpolate(logits, size=target.shape[1:], mode='bilinear')
pred = torch.argmax(logits, dim=1, keepdim=True)
num_classes = logits.shape[1]
if num_classes == 1:
pred = (torch.sigmoid(logits) > 0.5).type(torch.int64)
else:
pred = torch.argmax(logits, dim=1, keepdim=True)
target = target.unsqueeze(1)
ignore_mask = target == -1
target[ignore_mask] = 0
ignore_mask = ignore_mask.expand(-1, logits.shape[1], -1, -1)
ignore_mask = ignore_mask.expand(-1, num_classes if num_classes > 1 else 2, -1, -1)

binary_pred = torch.zeros(logits.shape, dtype=bool, device=self.device)
binary_target = torch.zeros(logits.shape, dtype=bool, device=self.device)
dims = list(logits.shape)
if num_classes == 1:
dims[1] = 2
binary_pred = torch.zeros(dims, dtype=bool, device=self.device)
binary_target = torch.zeros(dims, dtype=bool, device=self.device)
binary_pred.scatter_(dim=1, index=pred, src=torch.ones_like(binary_pred))
binary_target.scatter_(dim=1, index=target, src=torch.ones_like(binary_target))
binary_pred[ignore_mask] = 0
Expand Down
4 changes: 2 additions & 2 deletions segmentors/unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def __init__(self, args, cfg, encoder):

self.align_corners = False

self.num_classes = cfg['num_classes']
self.num_classes = 1 if cfg['binary'] else cfg['num_classes']
self.topology = encoder.topology

self.decoder = Decoder(self.topology)
Expand Down Expand Up @@ -57,7 +57,7 @@ def __init__(self, args, cfg, encoder):

self.align_corners = False

self.num_classes = cfg['num_classes']
self.num_classes = 1 if cfg['binary'] else cfg['num_classes']
self.topology = encoder.topology

self.decoder = Decoder(self.topology)
Expand Down
2 changes: 1 addition & 1 deletion segmentors/upernet.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def __init__(self, args, cfg, encoder, pool_scales=(1, 2, 3, 6)):

self.in_channels = [cfg['in_channels'] for _ in range(4)]
self.channels = cfg['channels']
self.num_classes = cfg['num_classes']
self.num_classes = 1 if cfg['binary'] else cfg['num_classes']

# PSP Module
self.psp_modules = PPM(
Expand Down
30 changes: 19 additions & 11 deletions utils/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,20 +26,28 @@ def __init__(self, cfg):
self.ignore_index = cfg["ignore_index"]

def forward(self, logits, target):
# Convert logits to probabilities using softmax
probs = F.softmax(logits, dim=1)

num_classes = logits.shape[1]

# Convert logits to probabilities using softmax or sigmoid
if num_classes == 1:
probs = torch.sigmoid(logits)
else:
probs = F.softmax(logits, dim=1)

mask = (target != self.ignore_index)
#mask_expand = mask.unsqueeze(1).expand_as(probs)
target_temp = target.clone()
target_temp[~mask] = 0

target_one_hot = F.one_hot(target_temp, num_classes=num_classes).permute(0, 3, 1, 2).float()
target_one_hot = target_one_hot * mask.unsqueeze(1).float()

intersection = torch.sum(probs * target_one_hot, dim=(2, 3))
union = torch.sum(probs + target_one_hot, dim=(2, 3))
target = target.clone()
target[~mask] = 0

if num_classes == 1:
target = target.unsqueeze(1)
else:
target = F.one_hot(target, num_classes=num_classes)
target = target.permute(0, 3, 1, 2)

target = target.float() * mask.unsqueeze(1).float()
intersection = torch.sum(probs * target, dim=(2, 3))
union = torch.sum(probs + target, dim=(2, 3))

dice_score = (2. * intersection + 1e-6) / (union + 1e-6)
# dice_loss = 1 - dice_score.mean(dim=1).mean()
Expand Down

0 comments on commit 00f1816

Please sign in to comment.