Skip to content

Commit

Permalink
Refactor visualization in SegmentationDOFA
Browse files Browse the repository at this point in the history
  • Loading branch information
valhassan committed Nov 12, 2024
1 parent 48e0244 commit 84cc762
Showing 1 changed file with 34 additions and 0 deletions.
34 changes: 34 additions & 0 deletions geo_deep_learning/tasks_with_models/segmentation_dofa.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
from torchmetrics.classification import MulticlassJaccardIndex
from torchmetrics.wrappers import ClasswiseWrapper
from models.dofa.dofa_seg import DOFASeg
from tools.utils import denormalization
from tools.script_model import script_model
from tools.visualization import visualize_prediction

class SegmentationDOFA(LightningModule):
def __init__(self,
Expand All @@ -17,17 +19,28 @@ def __init__(self,
image_size: tuple[int, int],
in_channels: int,
num_classes: int,
max_samples: int,
mean: List[float],
std: List[float],
data_type_max: float,
loss: Callable,
class_labels: List[str] = None,
class_colors: List[str] = None,
**kwargs: Any):
super().__init__()
self.save_hyperparameters()
self.class_colors = class_colors
self.max_samples = max_samples
self.mean = mean
self.std = std
self.data_type_max = data_type_max
self.num_classes = num_classes
self.model = DOFASeg(encoder, pretrained, image_size, self.num_classes)
self.loss = loss
self.metric= MulticlassJaccardIndex(num_classes=num_classes, average=None, zero_division=np.nan)
self.labels = [str(i) for i in range(num_classes)] if class_labels is None else class_labels
self.classwise_metric = ClasswiseWrapper(self.metric, labels=self.labels)
self._total_samples_visualized = 0

def forward(self, image: Tensor) -> Tensor:
return self.model(image)
Expand Down Expand Up @@ -58,6 +71,7 @@ def validation_step(self, batch, batch_idx):

def test_step(self, batch, batch_idx):
x = batch["image"]
image_names = batch["image_name"]
y = batch["mask"]
y = y.squeeze(1).long()
y_hat = self(x)
Expand All @@ -68,6 +82,26 @@ def test_step(self, batch, batch_idx):
self.log_dict(test_metrics,
prog_bar=True, logger=True,
on_step=False, on_epoch=True, sync_dist=True, rank_zero_only=True)
if self._total_samples_visualized < self.max_samples:
remaining_samples = self.max_samples - self._total_samples_visualized
num_samples = min(remaining_samples, len(x))
for i in range(num_samples):
image = x[i]
image_name = image_names[i]
image = denormalization(image, mean=self.mean, std=self.std, data_type_max=self.data_type_max)
fig = visualize_prediction(image,
y[i],
y_hat[i],
image_name,
self.num_classes,
class_colors=self.class_colors)
artifact_file = f"test/{Path(image_name).stem}/idx_{i}.png"
self.logger.experiment.log_figure(figure=fig,
artifact_file=artifact_file,
run_id=self.logger.run_id)
self._total_samples_visualized += 1
if self._total_samples_visualized >= self.max_samples:
break

def on_train_end(self):
if self.trainer.is_global_zero and self.trainer.checkpoint_callback is not None:
Expand Down

0 comments on commit 84cc762

Please sign in to comment.