Skip to content

Commit

Permalink
experimental changes
Browse files Browse the repository at this point in the history
  • Loading branch information
valhassan committed Nov 14, 2024
1 parent ac629d4 commit b58f69f
Showing 1 changed file with 12 additions and 7 deletions.
19 changes: 12 additions & 7 deletions geo_deep_learning/tasks_with_models/segmentation_dofa.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,13 @@ def __init__(self,
self.num_classes = num_classes
self.model = DOFASeg(encoder, pretrained, image_size, self.num_classes)
self.loss = loss
self.metric= MulticlassJaccardIndex(num_classes=num_classes, average=None, zero_division=np.nan)
self.iou_metric = MulticlassJaccardIndex(num_classes=num_classes,
average=None,
zero_division=np.nan
)
self.labels = [str(i) for i in range(num_classes)] if class_labels is None else class_labels
self.classwise_metric = ClasswiseWrapper(self.metric, labels=self.labels)
self.iou_classwise_metric = ClasswiseWrapper(self.iou_metric,
labels=self.labels)
self._total_samples_visualized = 0

def forward(self, image: Tensor) -> Tensor:
Expand Down Expand Up @@ -77,11 +81,9 @@ def test_step(self, batch, batch_idx):
y_hat = self(x)
loss = self.loss(y_hat, y)
y_hat = y_hat.softmax(dim=1).argmax(dim=1)
test_metrics = self.classwise_metric(y_hat, y)
test_metrics["loss"] = loss
self.log_dict(test_metrics,
prog_bar=True, logger=True,
on_step=False, on_epoch=True, sync_dist=True, rank_zero_only=True)
self.test_metrics = self.iou_classwise_metric(y_hat, y)
self.test_metrics["loss"] = loss

if self._total_samples_visualized < self.max_samples:
remaining_samples = self.max_samples - self._total_samples_visualized
num_samples = min(remaining_samples, len(x))
Expand All @@ -102,6 +104,9 @@ def test_step(self, batch, batch_idx):
self._total_samples_visualized += 1
if self._total_samples_visualized >= self.max_samples:
break
def on_test_epoch_end(self):
print(f"test_metrics: {self.test_metrics}")
self.log_dict(self.test_metrics)

def on_train_end(self):
if self.trainer.is_global_zero and self.trainer.checkpoint_callback is not None:
Expand Down

0 comments on commit b58f69f

Please sign in to comment.