Skip to content

Commit

Permalink
Refactor SegmentationCallback to improve visualization and logging
Browse files Browse the repository at this point in the history
  • Loading branch information
valhassan committed Nov 12, 2024
1 parent f1fb556 commit ac629d4
Showing 1 changed file with 42 additions and 51 deletions.
93 changes: 42 additions & 51 deletions geo_deep_learning/tools/callbacks/segmentation_callback.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,27 @@
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from matplotlib.colors import ListedColormap
from lightning.pytorch.callbacks import ModelCheckpoint
from tools.visualization import visualize_prediction
from tools.utils import denormalization

class SegmentationCallback(ModelCheckpoint):

def __init__(self, max_samples: int = 3, class_colors = None, *args, **kwargs):
def __init__(self,
max_samples: int = 3,
mean = None,
std = None,
num_classes = None,
data_type_max = 255,
class_colors = None,
*args, **kwargs):
super().__init__(*args, **kwargs)
self.max_samples = max_samples
self.mean = mean
self.std = std
self.num_classes = num_classes
self.data_type_max = data_type_max
self.class_colors = class_colors
self.current_batch = None
self.current_outputs = None
Expand All @@ -20,13 +34,14 @@ def on_validation_batch_end(self,
batch,
batch_idx,
dataloader_idx=0):
self.current_batch = batch
self.current_outputs = outputs

if trainer.is_global_zero:
self.current_batch = batch
self.current_outputs = outputs

def _save_checkpoint(self, trainer, filepath):
super()._save_checkpoint(trainer, filepath)
self._log_visualizations(trainer)
if trainer.is_global_zero:
self._log_visualizations(trainer, None)

def _setup_colormap(self):
if self.class_colors is None:
Expand All @@ -35,57 +50,33 @@ def _setup_colormap(self):
else:
self.cmap = ListedColormap(self.class_colors)

def _log_visualizations(self, trainer):
def _log_visualizations(self, trainer, batch_idx):
if self.current_batch is not None and self.current_outputs is not None:
image_batch = self.current_batch["image"]
mask_batch = self.current_batch["mask"]
batch_image_name = self.current_batch["image_name"]
batch_mask_name = self.current_batch["mask_name"]
batch_size = image_batch.shape[0]
N = min(self.max_samples, batch_size)
num_classes = mask_batch.max().item() + 1 if self.class_colors is None else len(self.class_colors)

fig, axes = plt.subplots(N, 3, figsize=(15, 5 * N))
axes = axes.reshape(N, 3) if N > 1 else axes.reshape(1, 3)


for i in range(N):
image = image_batch[i]
mask = mask_batch[i]
image_name = batch_image_name[i]
mask_name = batch_mask_name[i]
output = self.current_outputs[i]
image = (image.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
mask = mask.squeeze(0).long().cpu().numpy()
output = output.cpu().numpy()
ax_image, ax_mask, ax_output = axes[i]

ax_image.imshow(image)
ax_image.set_title("Image")
ax_image.axis("off")
ax_image.text(0.5, -0.1, f"{image_name}",
transform=ax_image.transAxes,
ha='center', va='top',
wrap=True)

ax_mask.imshow(mask, cmap=self.cmap, vmin=0, vmax=num_classes-1)
ax_mask.set_title("Mask")
ax_mask.axis("off")
# ax_mask.text(0.5, -0.1, f"{mask_name}",
# transform=ax_mask.transAxes,
# ha='center', va='top',
# wrap=True)

ax_output.imshow(output, cmap=self.cmap, vmin=0, vmax=num_classes-1)
ax_output.set_title('Output')
ax_output.axis("off")

plt.tight_layout()
artifact_file = f"val/predictions_epoch_{trainer.current_epoch}.png"
trainer.logger.experiment.log_figure(figure=fig,
artifact_file = artifact_file,
run_id=trainer.logger.run_id)
plt.close(fig)
num_classes = mask_batch.max().item() + 1 if self.num_classes is None else self.num_classes
try:
for i in range(N):
image = image_batch[i]
mask = mask_batch[i]
image_name = batch_image_name[i]
output = self.current_outputs[i]
image = denormalization(image, self.mean, self.std, self.data_type_max)
fig = visualize_prediction(image, mask, output,
sample_name=image_name,
num_classes=num_classes,
class_colors=self.class_colors)
artifact_file = f"val/{Path(image_name).stem}/idx_{i}_epoch_{trainer.current_epoch}.png"
trainer.logger.experiment.log_figure(figure=fig,
artifact_file = artifact_file,
run_id=trainer.logger.run_id)
except Exception as e:
print(f"Error in logging visualizations: {e}")

self.current_batch = None
self.current_outputs = None
finally:
self.current_batch = None
self.current_outputs = None

0 comments on commit ac629d4

Please sign in to comment.