From bb4fff6a26af0b1c2b2c4030d3e6b825d543c58f Mon Sep 17 00:00:00 2001
From: valhassan
Date: Wed, 27 Nov 2024 15:15:57 -0500
Subject: [PATCH] Enhance SegmentationSegformer with improved metrics and
 visualization

- Added support for loading model weights from a specified checkpoint path.
- Updated the metric from MulticlassJaccardIndex to MeanIoU for better segmentation evaluation.
- Introduced visualization of test predictions, with denormalization and image logging.
- Added max_samples, mean, std, and class_colors parameters to the model configuration.
- Added a warning filter to suppress the kornia-triggered message about changed grid_sample and affine_grid behavior.
---
 .../segmentation_segformer.py                 | 72 +++++++++++++++----
 1 file changed, 60 insertions(+), 12 deletions(-)

diff --git a/geo_deep_learning/tasks_with_models/segmentation_segformer.py b/geo_deep_learning/tasks_with_models/segmentation_segformer.py
index 9c43a6eb..aa4d9acb 100644
--- a/geo_deep_learning/tasks_with_models/segmentation_segformer.py
+++ b/geo_deep_learning/tasks_with_models/segmentation_segformer.py
@@ -1,31 +1,57 @@
+import warnings
+# Ignore warning about default grid_sample and affine_grid behavior triggered by kornia
+warnings.filterwarnings("ignore", message="Default grid_sample and affine_grid behavior has changed")
 import numpy as np
 import torch
 from pathlib import Path
 import matplotlib.pyplot as plt
 from torch import Tensor
-from typing import Any, Callable, Dict, List
+from typing import Any, Callable, Dict, List, Optional
 from lightning.pytorch import LightningModule, LightningDataModule
-from torchmetrics.classification import MulticlassJaccardIndex
+from torchmetrics.segmentation import MeanIoU
 from torchmetrics.wrappers import ClasswiseWrapper
 from models.segformer import SegFormer
 from tools.script_model import script_model
+from tools.utils import denormalization
+from tools.visualization import visualize_prediction
 
 class SegmentationSegformer(LightningModule):
     def __init__(self,
                  encoder: str,
-                 in_channels: int,
+                 in_channels: int,
                  num_classes: int,
+                 max_samples: int,
+                 mean: List[float],
+                 std: List[float],
+                 data_type_max: float,
                  loss: Callable,
+                 weights: str = None,
                  class_labels: List[str] = None,
+                 class_colors: List[str] = None,
+                 weights_from_checkpoint_path: Optional[str] = None,
                  **kwargs: Any):
         super().__init__()
         self.save_hyperparameters()
+        self.max_samples = max_samples
+        self.mean = mean
+        self.std = std
+        self.data_type_max = data_type_max
+        self.class_colors = class_colors
         self.num_classes = num_classes
-        self.model = SegFormer(encoder, in_channels, self.num_classes)
+        self.model = SegFormer(encoder, in_channels, weights, self.num_classes)
+        if weights_from_checkpoint_path:
+            print(f"Loading weights from checkpoint: {weights_from_checkpoint_path}")
+            checkpoint = torch.load(weights_from_checkpoint_path)
+            self.load_state_dict(checkpoint['state_dict'])
         self.loss = loss
-        self.metric= MulticlassJaccardIndex(num_classes=num_classes, average=None, zero_division=np.nan)
+        self.iou_metric = MeanIoU(num_classes=num_classes,
+                                  per_class=True,
+                                  input_format="index",
+                                  include_background=True
+                                  )
         self.labels = [str(i) for i in range(num_classes)] if class_labels is None else class_labels
-        self.classwise_metric = ClasswiseWrapper(self.metric, labels=self.labels)
+        self.iou_classwise_metric = ClasswiseWrapper(self.iou_metric, labels=self.labels)
+        self._total_samples_visualized = 0
 
     def forward(self, image: Tensor) -> Tensor:
         return self.model(image)
@@ -61,11 +87,33 @@ def test_step(self, batch, batch_idx):
         y_hat = self(x)
         loss = self.loss(y_hat, y)
         y_hat = y_hat.softmax(dim=1).argmax(dim=1)
-        test_metrics = self.classwise_metric(y_hat, y)
-        test_metrics["loss"] = loss
-        self.log_dict(test_metrics,
-                      prog_bar=True, logger=True,
-                      on_step=False, on_epoch=True, sync_dist=True, rank_zero_only=True)
+        metrics = self.iou_classwise_metric(y_hat, y)
+        metrics["test_loss"] = loss
+
+        if self._total_samples_visualized < self.max_samples:
+            remaining_samples = self.max_samples - self._total_samples_visualized
+            num_samples = min(remaining_samples, len(x))
+            for i in range(num_samples):
+                image = x[i]
+                image_name = batch["image_name"][i]
+                image = denormalization(image, mean=self.mean, std=self.std, data_type_max=self.data_type_max)
+                fig = visualize_prediction(image,
+                                           y[i],
+                                           y_hat[i],
+                                           image_name,
+                                           self.num_classes,
+                                           class_colors=self.class_colors)
+                artifact_file = f"test/{Path(image_name).stem}/idx_{i}.png"
+                self.logger.experiment.log_figure(figure=fig,
+                                                  artifact_file=artifact_file,
+                                                  run_id=self.logger.run_id)
+                self._total_samples_visualized += 1
+                if self._total_samples_visualized >= self.max_samples:
+                    break
+
+        self.log_dict(metrics,
+                      prog_bar=False, logger=True,
+                      on_step=False, rank_zero_only=True)
 
     def on_train_end(self):
         if self.trainer.is_global_zero and self.trainer.checkpoint_callback is not None:
@@ -78,7 +126,7 @@ def on_train_end(self):
         self.export_model(best_model_path, best_model_export_path, self.trainer.datamodule)
 
     def export_model(self, checkpoint_path: str, export_path: str, datamodule: LightningDataModule):
-        input_channels = self.hparams["init_args"]["in_channels"]
+        input_channels = self.hparams["in_channels"]
         map_location = "cuda"
         if self.device.type == "cpu":
             map_location = "cpu"
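
A minimal sketch of how the new MeanIoU + ClasswiseWrapper pairing reports per-class IoU, assuming torchmetrics >= 1.4 (where torchmetrics.segmentation.MeanIoU accepts input_format="index"); the two labels are hypothetical stand-ins for the class_labels argument:

    import torch
    from torchmetrics.segmentation import MeanIoU
    from torchmetrics.wrappers import ClasswiseWrapper

    # Per-class IoU, mirroring the constructor arguments used in the patch.
    iou = MeanIoU(num_classes=2, per_class=True, input_format="index", include_background=True)
    classwise = ClasswiseWrapper(iou, labels=["background", "building"])

    # input_format="index" expects integer class maps, i.e. the shape produced
    # by y_hat.softmax(dim=1).argmax(dim=1) in test_step.
    preds = torch.randint(0, 2, (4, 64, 64))
    target = torch.randint(0, 2, (4, 64, 64))

    # The wrapper returns a dict keyed per label, e.g.
    # {'meaniou_background': tensor(...), 'meaniou_building': tensor(...)},
    # which test_step extends with "test_loss" before calling log_dict.
    print(classwise(preds, target))

Because ClasswiseWrapper yields a dict rather than a scalar, each class appears as its own logged metric instead of a single averaged IoU.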