Enhance SegmentationSegformer with improved metrics and visualization
- Added support for loading model weights from a specified checkpoint path.
- Updated the metric from MulticlassJaccardIndex to MeanIoU for better segmentation evaluation (see the metric sketch after this list).
- Introduced visualization of predictions with denormalization and logging of images during testing.
- Added parameters for max_samples, mean, std, and class_colors to enhance model configuration.
- Implemented a warning filter to suppress messages related to changes in grid_sample and affine_grid behavior in kornia.
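
To illustrate the metric change, here is a minimal sketch, not taken from the repository, of how MeanIoU with per_class=True pairs with ClasswiseWrapper. It assumes a torchmetrics release recent enough to support input_format="index"; the class labels are invented:

    import torch
    from torchmetrics.segmentation import MeanIoU
    from torchmetrics.wrappers import ClasswiseWrapper

    num_classes = 3
    labels = ["background", "building", "road"]  # hypothetical class names

    # per_class=True makes MeanIoU return one IoU value per class;
    # ClasswiseWrapper turns that tensor into a {label: value} dict that log_dict can consume.
    iou_classwise = ClasswiseWrapper(
        MeanIoU(num_classes=num_classes,
                per_class=True,
                input_format="index",
                include_background=True),
        labels=labels,
    )

    # "index" format: integer class maps of shape (batch, H, W) for both predictions and targets.
    preds = torch.randint(0, num_classes, (2, 64, 64))
    target = torch.randint(0, num_classes, (2, 64, 64))
    print(iou_classwise(preds, target))  # roughly {'meaniou_background': tensor(...), ...}

ClasswiseWrapper is kept from the previous setup, so the logged output remains one entry per class; only the underlying metric changes.
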
valhassan committed Nov 27, 2024
1 parent de17eb8 commit bb4fff6
Showing 1 changed file with 60 additions and 12 deletions.
72 changes: 60 additions & 12 deletions geo_deep_learning/tasks_with_models/segmentation_segformer.py
@@ -1,31 +1,57 @@
+import warnings
+# Ignore warning about default grid_sample and affine_grid behavior triggered by kornia
+warnings.filterwarnings("ignore", message="Default grid_sample and affine_grid behavior has changed")
 import numpy as np
 import torch
 from pathlib import Path
+import matplotlib.pyplot as plt
 from torch import Tensor
-from typing import Any, Callable, Dict, List
+from typing import Any, Callable, Dict, List, Optional
 from lightning.pytorch import LightningModule, LightningDataModule
-from torchmetrics.classification import MulticlassJaccardIndex
+from torchmetrics.segmentation import MeanIoU
 from torchmetrics.wrappers import ClasswiseWrapper
 from models.segformer import SegFormer
 from tools.script_model import script_model
+from tools.utils import denormalization
+from tools.visualization import visualize_prediction
 
 class SegmentationSegformer(LightningModule):
     def __init__(self,
                  encoder: str,
-                 in_channels: int,
+                 in_channels: int,
                  num_classes: int,
+                 max_samples: int,
+                 mean: List[float],
+                 std: List[float],
                  data_type_max: float,
                  loss: Callable,
                  weights: str = None,
                  class_labels: List[str] = None,
+                 class_colors: List[str] = None,
+                 weights_from_checkpoint_path: Optional[str] = None,
                  **kwargs: Any):
         super().__init__()
         self.save_hyperparameters()
+        self.max_samples = max_samples
+        self.mean = mean
+        self.std = std
         self.data_type_max = data_type_max
+        self.class_colors = class_colors
         self.num_classes = num_classes
-        self.model = SegFormer(encoder, in_channels, self.num_classes)
+        self.model = SegFormer(encoder, in_channels, weights, self.num_classes)
+        if weights_from_checkpoint_path:
+            print(f"Loading weights from checkpoint: {weights_from_checkpoint_path}")
+            checkpoint = torch.load(weights_from_checkpoint_path)
+            self.load_state_dict(checkpoint['state_dict'])
         self.loss = loss
-        self.metric= MulticlassJaccardIndex(num_classes=num_classes, average=None, zero_division=np.nan)
+        self.iou_metric = MeanIoU(num_classes=num_classes,
+                                  per_class=True,
+                                  input_format="index",
+                                  include_background=True
+                                  )
         self.labels = [str(i) for i in range(num_classes)] if class_labels is None else class_labels
-        self.classwise_metric = ClasswiseWrapper(self.metric, labels=self.labels)
+        self.iou_classwise_metric = ClasswiseWrapper(self.iou_metric, labels=self.labels)
+        self._total_samples_visualized = 0
 
     def forward(self, image: Tensor) -> Tensor:
         return self.model(image)
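
To make the expanded constructor concrete, a hypothetical instantiation follows; every value below (encoder id, channel count, statistics, colors, loss, checkpoint path) is a placeholder rather than the project's actual configuration:

    from torch import nn

    model = SegmentationSegformer(
        encoder="mit_b2",                   # hypothetical encoder identifier
        in_channels=3,
        num_classes=5,
        max_samples=16,                     # cap on test images visualized and logged
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
        data_type_max=255.0,
        loss=nn.CrossEntropyLoss(),
        weights=None,
        class_labels=["background", "building", "road", "water", "forest"],
        class_colors=["#000000", "#e31a1c", "#33a02c", "#1f78b4", "#ff7f00"],
        weights_from_checkpoint_path=None,  # or a path to a Lightning .ckpt to warm-start from
    )

When weights_from_checkpoint_path is provided, the constructor calls torch.load on it and loads checkpoint["state_dict"] into the module, so it expects a Lightning-style checkpoint whose keys match this LightningModule rather than only the backbone.
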
@@ -61,11 +87,33 @@ def test_step(self, batch, batch_idx):
         y_hat = self(x)
         loss = self.loss(y_hat, y)
         y_hat = y_hat.softmax(dim=1).argmax(dim=1)
-        test_metrics = self.classwise_metric(y_hat, y)
-        test_metrics["loss"] = loss
-        self.log_dict(test_metrics,
-                      prog_bar=True, logger=True,
-                      on_step=False, on_epoch=True, sync_dist=True, rank_zero_only=True)
+        metrics = self.iou_classwise_metric(y_hat, y)
+        metrics["test_loss"] = loss
+
+        if self._total_samples_visualized < self.max_samples:
+            remaining_samples = self.max_samples - self._total_samples_visualized
+            num_samples = min(remaining_samples, len(x))
+            for i in range(num_samples):
+                image = x[i]
+                image_name = batch["image_name"][i]
+                image = denormalization(image, mean=self.mean, std=self.std, data_type_max=self.data_type_max)
+                fig = visualize_prediction(image,
+                                           y[i],
+                                           y_hat[i],
+                                           image_name,
+                                           self.num_classes,
+                                           class_colors=self.class_colors)
+                artifact_file = f"test/{Path(image_name).stem}/idx_{i}.png"
+                self.logger.experiment.log_figure(figure=fig,
+                                                  artifact_file=artifact_file,
+                                                  run_id=self.logger.run_id)
+                self._total_samples_visualized += 1
+                if self._total_samples_visualized >= self.max_samples:
+                    break
+
+        self.log_dict(metrics,
+                      prog_bar=False, logger=True,
+                      on_step=False, rank_zero_only=True)
 
     def on_train_end(self):
         if self.trainer.is_global_zero and self.trainer.checkpoint_callback is not None:
@@ -78,7 +126,7 @@ def on_train_end(self):
             self.export_model(best_model_path, best_model_export_path, self.trainer.datamodule)
 
     def export_model(self, checkpoint_path: str, export_path: str, datamodule: LightningDataModule):
-        input_channels = self.hparams["init_args"]["in_channels"]
+        input_channels = self.hparams["in_channels"]
         map_location = "cuda"
         if self.device.type == "cpu":
             map_location = "cpu"
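
The new visualization path in test_step depends on pieces outside this diff: denormalization and visualize_prediction come from the repository's tools package and are not shown here, and the figure logging assumes an MLflow-backed logger. The sketch below, with an invented experiment name and a random image standing in for a real prediction figure, shows only that logging assumption: Lightning's MLFlowLogger exposes an MlflowClient as logger.experiment, and its log_figure stores a matplotlib figure as a run artifact.

    import numpy as np
    import matplotlib.pyplot as plt
    from lightning.pytorch.loggers import MLFlowLogger

    logger = MLFlowLogger(experiment_name="segformer_test")  # hypothetical experiment

    fig, ax = plt.subplots()
    ax.imshow(np.random.rand(64, 64), cmap="viridis")  # stand-in for visualize_prediction output
    ax.set_title("prediction preview")

    # This mirrors the call made in test_step: the figure is written as an artifact
    # under the active MLflow run at the given relative path.
    logger.experiment.log_figure(figure=fig,
                                 artifact_file="test/sample/idx_0.png",
                                 run_id=logger.run_id)
    plt.close(fig)

A logger whose experiment attribute is not an MlflowClient (for example TensorBoardLogger's SummaryWriter) provides neither log_figure nor run_id, so the test-time visualization as written presumes MLflow logging.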
