From b58f69fc3f39aca146df1fb30e93582756868e59 Mon Sep 17 00:00:00 2001
From: valhassan
Date: Thu, 14 Nov 2024 13:20:02 -0500
Subject: [PATCH] experimental changes

---
 .../tasks_with_models/segmentation_dofa.py    | 19 ++++++++++++-------
 1 file changed, 12 insertions(+), 7 deletions(-)

diff --git a/geo_deep_learning/tasks_with_models/segmentation_dofa.py b/geo_deep_learning/tasks_with_models/segmentation_dofa.py
index 2155c89a..f5c56599 100644
--- a/geo_deep_learning/tasks_with_models/segmentation_dofa.py
+++ b/geo_deep_learning/tasks_with_models/segmentation_dofa.py
@@ -37,9 +37,13 @@ def __init__(self,
         self.num_classes = num_classes
         self.model = DOFASeg(encoder, pretrained, image_size, self.num_classes)
         self.loss = loss
-        self.metric= MulticlassJaccardIndex(num_classes=num_classes, average=None, zero_division=np.nan)
+        self.iou_metric = MulticlassJaccardIndex(num_classes=num_classes,
+                                                 average=None,
+                                                 zero_division=np.nan
+                                                 )
         self.labels = [str(i) for i in range(num_classes)] if class_labels is None else class_labels
-        self.classwise_metric = ClasswiseWrapper(self.metric, labels=self.labels)
+        self.iou_classwise_metric = ClasswiseWrapper(self.iou_metric,
+                                                     labels=self.labels)
         self._total_samples_visualized = 0
 
     def forward(self, image: Tensor) -> Tensor:
@@ -77,11 +81,9 @@ def test_step(self, batch, batch_idx):
         y_hat = self(x)
         loss = self.loss(y_hat, y)
         y_hat = y_hat.softmax(dim=1).argmax(dim=1)
-        test_metrics = self.classwise_metric(y_hat, y)
-        test_metrics["loss"] = loss
-        self.log_dict(test_metrics,
-                      prog_bar=True, logger=True,
-                      on_step=False, on_epoch=True, sync_dist=True, rank_zero_only=True)
+        self.test_metrics = self.iou_classwise_metric(y_hat, y)
+        self.test_metrics["loss"] = loss
+
         if self._total_samples_visualized < self.max_samples:
             remaining_samples = self.max_samples - self._total_samples_visualized
             num_samples = min(remaining_samples, len(x))
@@ -102,6 +104,9 @@ def test_step(self, batch, batch_idx):
                     self._total_samples_visualized += 1
                     if self._total_samples_visualized >= self.max_samples:
                         break
+    def on_test_epoch_end(self):
+        print(f"test_metrics: {self.test_metrics}")
+        self.log_dict(self.test_metrics)
 
     def on_train_end(self):
         if self.trainer.is_global_zero and self.trainer.checkpoint_callback is not None: