Skip to content

Commit

Permalink
Refactor label to mask and update visualization in SegmentationCallback
Browse files Browse the repository at this point in the history
  • Loading branch information
valhassan committed Oct 25, 2024
1 parent 0d26008 commit f7aa05b
Showing 1 changed file with 21 additions and 9 deletions.
30 changes: 21 additions & 9 deletions geo_deep_learning/tools/callbacks/segmentation_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,31 +38,43 @@ def _setup_colormap(self):
def _log_visualizations(self, trainer):
if self.current_batch is not None and self.current_outputs is not None:
image_batch = self.current_batch["image"]
label_batch = self.current_batch["label"]
mask_batch = self.current_batch["mask"]
batch_image_name = self.current_batch["image_name"]
batch_mask_name = self.current_batch["mask_name"]
batch_size = image_batch.shape[0]
N = min(self.max_samples, batch_size)
num_classes = label_batch.max().item() + 1 if self.class_colors is None else len(self.class_colors)
num_classes = mask_batch.max().item() + 1 if self.class_colors is None else len(self.class_colors)

fig, axes = plt.subplots(N, 3, figsize=(15, 5 * N))
axes = axes.reshape(N, 3) if N > 1 else axes.reshape(1, 3)


for i in range(N):
image = image_batch[i]
label = label_batch[i]
mask = mask_batch[i]
image_name = batch_image_name[i]
mask_name = batch_mask_name[i]
output = self.current_outputs[i]
image = (image.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
label = label.squeeze(0).long().cpu().numpy()
mask = mask.squeeze(0).long().cpu().numpy()
output = output.cpu().numpy()
ax_image, ax_label, ax_output = axes[i]
ax_image, ax_mask, ax_output = axes[i]

ax_image.imshow(image)
ax_image.set_title('Image')
ax_image.set_title("Image")
ax_image.axis("off")
ax_image.text(0.5, -0.1, f"{image_name}",
transform=ax_image.transAxes,
ha='center', va='top',
wrap=True)

ax_label.imshow(label, cmap=self.cmap, vmin=0, vmax=num_classes-1)
ax_label.set_title('Label')
ax_label.axis("off")
ax_mask.imshow(mask, cmap=self.cmap, vmin=0, vmax=num_classes-1)
ax_mask.set_title("Mask")
ax_mask.axis("off")
# ax_mask.text(0.5, -0.1, f"{mask_name}",
# transform=ax_mask.transAxes,
# ha='center', va='top',
# wrap=True)

ax_output.imshow(output, cmap=self.cmap, vmin=0, vmax=num_classes-1)
ax_output.set_title('Output')
Expand Down

0 comments on commit f7aa05b

Please sign in to comment.