Skip to content

Commit

Permalink
Refactor label to mask in SegmentationSegformer
Browse files Browse the repository at this point in the history
  • Loading branch information
valhassan committed Oct 25, 2024
1 parent 55626e4 commit 0d26008
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions geo_deep_learning/tasks_with_models/segmentation_segformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def forward(self, image: Tensor) -> Tensor:

def training_step(self, batch: Dict[str, Any], batch_idx: int):
x = batch["image"]
y = batch["label"]
y = batch["mask"]
y = y.squeeze(1).long()
y_hat = self(x)
loss = self.loss(y_hat, y)
Expand All @@ -44,7 +44,7 @@ def training_step(self, batch: Dict[str, Any], batch_idx: int):

def validation_step(self, batch, batch_idx):
x = batch["image"]
y = batch["label"]
y = batch["mask"]
y = y.squeeze(1).long()
y_hat = self(x)
loss = self.loss(y_hat, y)
Expand All @@ -56,7 +56,7 @@ def validation_step(self, batch, batch_idx):

def test_step(self, batch, batch_idx):
x = batch["image"]
y = batch["label"]
y = batch["mask"]
y = y.squeeze(1).long()
y_hat = self(x)
loss = self.loss(y_hat, y)
Expand Down

0 comments on commit 0d26008

Please sign in to comment.