From 0d2600872695fb108188e7428683082aaa2e9dbe Mon Sep 17 00:00:00 2001 From: valhassan Date: Fri, 25 Oct 2024 15:14:03 -0400 Subject: [PATCH] Refactor label to mask in SegmentationSegformer --- .../tasks_with_models/segmentation_segformer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/geo_deep_learning/tasks_with_models/segmentation_segformer.py b/geo_deep_learning/tasks_with_models/segmentation_segformer.py index d979905c..9c43a6eb 100644 --- a/geo_deep_learning/tasks_with_models/segmentation_segformer.py +++ b/geo_deep_learning/tasks_with_models/segmentation_segformer.py @@ -32,7 +32,7 @@ def forward(self, image: Tensor) -> Tensor: def training_step(self, batch: Dict[str, Any], batch_idx: int): x = batch["image"] - y = batch["label"] + y = batch["mask"] y = y.squeeze(1).long() y_hat = self(x) loss = self.loss(y_hat, y) @@ -44,7 +44,7 @@ def training_step(self, batch: Dict[str, Any], batch_idx: int): def validation_step(self, batch, batch_idx): x = batch["image"] - y = batch["label"] + y = batch["mask"] y = y.squeeze(1).long() y_hat = self(x) loss = self.loss(y_hat, y) @@ -56,7 +56,7 @@ def validation_step(self, batch, batch_idx): def test_step(self, batch, batch_idx): x = batch["image"] - y = batch["label"] + y = batch["mask"] y = y.squeeze(1).long() y_hat = self(x) loss = self.loss(y_hat, y)